diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index 4db6685..0000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,118 +0,0 @@ -version: 2 -jobs: - test: - docker: - - image: circleci/python:3.10 - - working_directory: ~/repo - - steps: - - checkout - - # Download and cache dependencies - - restore_cache: - keys: - - v2-dependencies-{{ checksum "setup/requirements.txt" }} - # fallback to using the latest cache if no exact match is found - - v2-dependencies- - - - run: - name: install dependencies - command: | - python -m venv .venv - . .venv/bin/activate - pip install -r setup/requirements.txt - - - save_cache: - paths: - - .venv - key: v2-dependencies-{{ checksum "setup/requirements.txt" }} - - run: - name: run tests - command: | - . .venv/bin/activate - cp fvserver/example_settings.py fvserver/settings.py - # python manage.py test - python manage.py migrate - - run: - name: run linting - command: | - . .venv/bin/activate - black --check . - - - store_artifacts: - path: test-reports - destination: test-reports - build_latest: - docker: - - image: docker:18.06.1-ce-git - - steps: - - checkout - - setup_remote_docker - - run: docker build -t macadmins/crypt-server:latest . - - run: docker login -u $DOCKER_USER -p $DOCKER_PASS - - run: docker push macadmins/crypt-server:latest - - run: apk add python2 py2-pip - - run: pip install requests - - run: python remote_build.py latest - - build_branch: - docker: - - image: docker:18.06.1-ce-git - - steps: - - checkout - - setup_remote_docker: - docker_layer_caching: true - - run: docker build -t macadmins/crypt-server:$CIRCLE_BRANCH . - - run: docker login -u $DOCKER_USER -p $DOCKER_PASS - - run: docker push macadmins/crypt-server:$CIRCLE_BRANCH - - run: apk update - - run: apk add python2 py2-pip - - run: pip install requests - - run: python remote_build.py $CIRCLE_BRANCH - - build_tag: - docker: - - image: docker:18.06.1-ce-git - - steps: - - checkout - - setup_remote_docker: - docker_layer_caching: true - - run: docker build -t macadmins/crypt-server:$CIRCLE_TAG . - - run: docker login -u $DOCKER_USER -p $DOCKER_PASS - - run: docker push macadmins/crypt-server:$CIRCLE_TAG - - run: apk add python2 py2-pip - - run: pip install requests - - run: python remote_build.py $CIRCLE_TAG - -workflows: - version: 2 - build_and_test: - jobs: - - test: - filters: - tags: - only: /.*/ - - build_latest: - requires: - - test - filters: - branches: - only: master - - build_branch: - requires: - - test - filters: - branches: - ignore: master - - build_tag: - requires: - - test - filters: - tags: - only: /.*/ - branches: - ignore: /.*/ diff --git a/.flake8 b/.flake8 deleted file mode 100644 index e348728..0000000 --- a/.flake8 +++ /dev/null @@ -1,5 +0,0 @@ -[flake8] -ignore = F401,F405,E402,F403 -max-line-length = 100 -max-complexity = 100 -exclude = venv/*,*/migrations/* diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..1cee315 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,135 @@ +name: Release + +on: + workflow_dispatch: + inputs: + version: + description: 'Release version (e.g., v1.0.0)' + required: true + type: string + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + include: + - goos: linux + goarch: amd64 + - goos: linux + goarch: arm64 + - goos: darwin + goarch: amd64 + - goos: darwin + goarch: arm64 + + steps: + - uses: actions/checkout@v4 + with: + ref: master + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.22" + + - name: Build crypt-server + env: + GOOS: ${{ matrix.goos }} + GOARCH: ${{ matrix.goarch }} + CGO_ENABLED: 0 + run: | + go build -ldflags="-s -w -X main.Version=${{ inputs.version }}" -o crypt-server ./cmd/crypt-server + + - name: Build cryptctl + env: + GOOS: ${{ matrix.goos }} + GOARCH: ${{ matrix.goarch }} + CGO_ENABLED: 0 + run: | + go build -ldflags="-s -w -X main.Version=${{ inputs.version }}" -o cryptctl ./cmd/cryptctl + + - name: Create archive + run: | + mkdir -p dist + cp -r web dist/ + cp crypt-server cryptctl dist/ + cd dist + zip -r ../crypt-server-${{ inputs.version }}-${{ matrix.goos }}-${{ matrix.goarch }}.zip . + + - name: Upload artifact + uses: actions/upload-artifact@v4 + with: + name: crypt-server-${{ matrix.goos }}-${{ matrix.goarch }} + path: crypt-server-${{ inputs.version }}-${{ matrix.goos }}-${{ matrix.goarch }}.zip + + release: + needs: build + runs-on: ubuntu-latest + permissions: + contents: write + + steps: + - uses: actions/checkout@v4 + with: + ref: master + + - name: Download all artifacts + uses: actions/download-artifact@v4 + with: + path: artifacts + + - name: Create Release + uses: softprops/action-gh-release@v2 + with: + tag_name: ${{ inputs.version }} + name: ${{ inputs.version }} + draft: false + prerelease: false + files: artifacts/**/*.zip + generate_release_notes: true + + docker: + needs: release + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + - uses: actions/checkout@v4 + with: + ref: master + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata for Docker + id: meta + uses: docker/metadata-action@v5 + with: + images: ghcr.io/${{ github.repository }} + tags: | + type=raw,value=${{ inputs.version }} + type=raw,value=latest + + - name: Build and push Docker image + uses: docker/build-push-action@v5 + with: + context: . + platforms: linux/amd64,linux/arm64 + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..db4a8f6 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,76 @@ +name: Tests + +on: + push: + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.22" + + - name: Run unit tests + run: go test -v ./... + + integration-sqlite: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.22" + + - name: Build cryptctl + run: go build -o cryptctl ./cmd/cryptctl/ + + - name: Generate encryption key + run: ./cryptctl gen-key > /tmp/test_key.txt + + - name: Run SQLite integration tests + run: ./cryptctl integration-test -db sqlite -key-file /tmp/test_key.txt + + integration-postgres: + runs-on: ubuntu-latest + services: + postgres: + image: postgres:15 + env: + POSTGRES_USER: crypt + POSTGRES_PASSWORD: crypt_test_password + POSTGRES_DB: crypt_test + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: "1.22" + + - name: Build cryptctl + run: go build -o cryptctl ./cmd/cryptctl/ + + - name: Generate encryption key + run: ./cryptctl gen-key > /tmp/test_key.txt + + - name: Run PostgreSQL integration tests + run: | + ./cryptctl integration-test \ + -db postgres \ + -db-url "postgres://crypt:crypt_test_password@localhost:5432/crypt_test?sslmode=disable" \ + -key-file /tmp/test_key.txt diff --git a/.gitignore b/.gitignore index 6b3a2c9..471c094 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,16 @@ keyset *.db .vscode +.gocache/* + +/cryptctl +/crypt-server +.field-encryption-key + +# SAML config (contains environment-specific paths) +saml-config.yaml +okta-metadata.xml +sp.crt +sp.key + +.claude/settings.local.json \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 349f92e..0000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,10 +0,0 @@ -repos: -- repo: https://github.com/adamchainz/django-upgrade - rev: "1.10.0" - hooks: - - id: django-upgrade - args: [--target-version, "4.1"] -- repo: https://github.com/psf/black - rev: 21.12b0 - hooks: - - id: black \ No newline at end of file diff --git a/AGENTS.md b/AGENTS.md new file mode 120000 index 0000000..681311e --- /dev/null +++ b/AGENTS.md @@ -0,0 +1 @@ +CLAUDE.md \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..c927421 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,5 @@ +- Always write unit tests for all code you write +- Unit tests should use Testify +- All user input should be protected from CSRF +- If you are editing code that is configured by end users, either via environment variables, config file or flags, ensure that you update documentation accordingly +- Use `go fmt` to format your code diff --git a/Dockerfile b/Dockerfile index 6f0be9c..e3dbc72 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,55 +1,37 @@ -FROM python:3.10.11-alpine3.16 +# Build stage +FROM golang:1.22-alpine AS builder -LABEL maintainer="graham@grahamgilbert.com" +RUN apk add --no-cache git -ENV APP_DIR /home/docker/crypt -ENV DEBUG false -ENV LANG en -ENV TZ Etc/UTC -ENV LC_ALL en_US.UTF-8 +WORKDIR /app +# Copy go mod files first for better caching +COPY go.mod go.sum ./ +RUN go mod download +# Copy source code +COPY . . -RUN set -ex \ - && apk add --no-cache --virtual .build-deps \ - gcc \ - git \ - openssl-dev \ - build-base \ - libffi-dev \ - libc-dev \ - musl-dev \ - linux-headers \ - pcre-dev \ - postgresql-dev \ - xmlsec-dev \ - tzdata \ - postgresql-libs \ - libpq +# Build the binary +RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o crypt-server ./cmd/crypt-server -COPY setup/requirements.txt /tmp/requirements.txt +# Runtime stage +FROM alpine:3.19 -RUN set -ex \ - && LIBRARY_PATH=/lib:/usr/lib /bin/sh -c "pip install --no-cache-dir -r /tmp/requirements.txt" \ - && rm /tmp/requirements.txt +RUN apk add --no-cache ca-certificates tzdata -COPY / $APP_DIR -COPY docker/settings.py $APP_DIR/fvserver/ -COPY docker/settings_import.py $APP_DIR/fvserver/ -COPY docker/gunicorn_config.py $APP_DIR/ -COPY docker/django/management/ $APP_DIR/server/management/ -COPY docker/run.sh /run.sh +WORKDIR /app -RUN chmod +x /run.sh \ - && mkdir -p /home/app \ - && ln -s ${APP_DIR} /home/app/crypt +# Copy binary from builder +COPY --from=builder /app/crypt-server . -WORKDIR ${APP_DIR} -# don't use this key anywhere else, this is just for collectstatic to run -RUN export FIELD_ENCRYPTION_KEY="jKAv1Sde8m6jCYFnmps0iXkUfAilweNVjbvoebBrDwg="; python manage.py collectstatic --noinput; export FIELD_ENCRYPTION_KEY="" +# Copy web assets +COPY --from=builder /app/web ./web -EXPOSE 8000 +# Create non-root user +RUN adduser -D -u 1000 crypt +USER crypt -VOLUME $APP_DIR/keyset +EXPOSE 8080 -CMD ["/run.sh"] +ENTRYPOINT ["/app/crypt-server"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..e30877f --- /dev/null +++ b/Makefile @@ -0,0 +1,36 @@ +VERSION := $(shell cat VERSION) +LDFLAGS := -ldflags "-X crypt-server/internal/app.Version=$(VERSION)" + +.PHONY: build cryptctl clean test run-sqlite run-sqlite-saml + +build: + go build $(LDFLAGS) -o crypt-server ./cmd/crypt-server + +cryptctl: + go build -o cryptctl ./cmd/cryptctl + +test: + go test ./... + +clean: + rm -f crypt-server cryptctl + +run-sqlite: build cryptctl + @echo "Starting crypt-server with SQLite..." + @echo "Server will be available at http://localhost:8080" + @test -f .field-encryption-key || ./cryptctl gen-key > .field-encryption-key + SQLITE_PATH=./crypt.db \ + FIELD_ENCRYPTION_KEY=$$(cat .field-encryption-key) \ + SESSION_KEY=$${SESSION_KEY:-$$(./cryptctl gen-key)} \ + ./crypt-server + +run-sqlite-saml: build cryptctl + @echo "Starting crypt-server with SQLite and SAML..." + @echo "Server will be available at http://localhost:8080" + @test -f .field-encryption-key || ./cryptctl gen-key > .field-encryption-key + @test -f saml-config.yaml || (echo "Error: saml-config.yaml not found" && exit 1) + SQLITE_PATH=./crypt.db \ + FIELD_ENCRYPTION_KEY=$$(cat .field-encryption-key) \ + SESSION_KEY=$${SESSION_KEY:-$$(./cryptctl gen-key)} \ + SAML_CONFIG_FILE=./saml-config.yaml \ + ./crypt-server diff --git a/NOTICE b/NOTICE index 693ac79..80f1608 100644 --- a/NOTICE +++ b/NOTICE @@ -1,4 +1,4 @@ -Copyright 2012-2016 Graham Gilbert +Copyright 2012-2026 Graham Gilbert Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at diff --git a/README.md b/README.md index eaae4c1..7d78273 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,6 @@ # Crypt-Server -**[Crypt][1]** is a tool for securely storing secrets such as FileVault 2 recovery keys. It is made up of a client app, and a Django web app for storing the keys. - -This Docker image contains the fully configured Crypt Django web app. A default admin user has been preconfigured, use admin/password to login. -If you intend on using the server for anything semi-serious it is a good idea to change the password or add a new admin user and delete the default one. +**[Crypt][1]** is a tool for securely storing secrets such as FileVault 2 recovery keys. It is made up of a client app, and a web app for storing the keys. ## Features @@ -14,33 +11,152 @@ If you intend on using the server for anything semi-serious it is a good idea to [1]: https://github.com/grahamgilbert/Crypt +## Migration from Django + +### Step 1: Export data from Django + +Export your Django database to a JSON fixture: + +```bash +# If running Django directly: +cd /path/to/legacy/crypt-server +./manage.py dumpdata > legacy.json + +# If running Django in Docker: +docker exec python manage.py dumpdata > legacy.json +``` + +### Step 2: Generate a new encryption key + +Generate a new AES-GCM encryption key for the Go backend: + +```bash +./cryptctl gen-key > new-field-encryption-key.txt +``` + +### Step 3: Convert the fixture + +Convert the Django JSON fixture into the new format. This re-encrypts all secrets from Django's Fernet encryption to the new AES-GCM format: + +```bash +./cryptctl import-fixture \ + -input legacy.json \ + -output migration-export.json \ + -legacy-key-file legacy-field-encryption-key.txt \ + -new-key-file new-field-encryption-key.txt \ + -password-map password-map.csv +``` + +The optional password map CSV allows you to set passwords for users who should have local login enabled. Any users not in this map will be configured for SAML authentication only. The CSV should have the following format (including header row): + +```csv +username_or_email,password,must_reset_password +admin@example.com,Str0ng!Passw0rd,false +``` + +Users not in the password map will be configured for SAML authentication only. + +### Step 4: Import into the new server + +Import the converted fixture into the Go server. **The database must be empty** (no existing computers, secrets, requests, or users). + +First, set the required environment variables: + +```bash +export FIELD_ENCRYPTION_KEY=$(cat new-field-encryption-key.txt) +export SESSION_KEY=$(./cryptctl gen-key) +``` + +Then run the import: + +```bash +./crypt-server -import-fixture migration-export.json +``` + +The import will: + +- Verify the database is empty (fails if any data exists) +- Import all computers with their original IDs +- Import all secrets (already re-encrypted with the new key) +- Import all users with their authentication settings +- Import all requests with their approval status + +After import completes, start the server normally (the environment variables are already set): + +```bash +./crypt-server +``` + ## Installation instructions -It is recommended that you use [Docker](https://github.com/grahamgilbert/Crypt-Server/blob/master/docs/Docker.md) to run this, but if you wish to run directly on a host, installation instructions are over in the [docs directory](https://github.com/grahamgilbert/Crypt-Server/blob/master/docs/Installation_on_Ubuntu_1404.md) +It is recommended that you use [Docker](docs/Docker.md) to run this. See the Docker documentation for complete setup instructions. -### Migrating from versions earlier than Crypt 3.0 +### Migrating from the Django version -Crypt 3 changed it's encryption backend, so when migrating from versions earlier than Crypt 3.0, you should first run Crypt 3.2.0 to perform the migration, and then upgrade to the latest version. The last version to support legacy migrations was Crypt 3.2. +If you are migrating from the Django version of Crypt Server, follow the "Migration from Django" steps above. If you are running a version earlier than Crypt 3.0, you should first upgrade to Django Crypt 3.2.0 to migrate from the legacy encryption format, then follow the migration steps to move to the Go version. ## Settings -All settings that would be entered into `settings.py` can also be passed into the Docker container as environment variables. +All settings are configured via environment variables. + +### Required + +- `FIELD_ENCRYPTION_KEY` - Base64-encoded 32-byte key for encrypting secrets. Generate with `./cryptctl gen-key`. + +- `SESSION_KEY` - A random string (at least 32 bytes) used to sign session cookies. Generate with `./cryptctl gen-key`. + +### Database (one required) + +- `DATABASE_URL` - PostgreSQL connection string (e.g., `postgres://user:pass@host:5432/dbname`). Mutually exclusive with `SQLITE_PATH`. + +- `SQLITE_PATH` - SQLite database file path. Must be a file path (not `:memory:`). Mutually exclusive with `DATABASE_URL`. + +### Optional + +- `SESSION_COOKIE_SECURE` - Set to `true` to mark session cookies as secure (recommended when using HTTPS). Default: `false`. + +- `SAML_CONFIG_FILE` - Path to a YAML file containing SAML configuration. See `docs/saml-config.sample.yaml` for all supported fields. + +- `APPROVE_OWN` - Allow users with approval permissions to approve their own key requests. Default: `false`. + +- `ALL_APPROVE` - Grant all users approval permissions when they log in. Default: `false`. + +- `ROTATE_VIEWED_SECRETS` - Instruct compatible clients (Crypt 3.2.0+) to rotate and re-escrow secrets after viewing. Default: `false`. + +## Database migrations + +The Go server applies embedded SQL migrations on startup and records applied versions in `schema_migrations`. + +Migration file naming: `NNN_description.sql` (for example, `002_add_requests.sql`). + +Flags: + +- `-validate-migrations` - Validate embedded migrations and exit. +- `-print-migrations` - Print embedded migrations and exit. +- `-migrations-driver` - Limit the validation/print target to `postgres` or `sqlite` (default: both). + +Example: -- `FIELD_ENCRYPTION_KEY` - The key to use when encrypting the secrets. This is required. +``` bash +./crypt-server -validate-migrations -migrations-driver=postgres +``` -- `SEND_EMAIL` - Crypt Server can send email notifcations when secrets are requested and approved. Set `SEND_EMAIL` to True, and set `HOST_NAME` to your server's host and URL scheme (e.g. `https://crypt.example.com`). For configuring your email settings, see the [Django documentation](https://docs.djangoproject.com/en/3.1/ref/settings/#std:setting-EMAIL_HOST). +## First admin creation -- `EMAIL_SENDER` - The email address to send emaiil notifications from when secrets are requests and approved. Ensure this is verified if you are using SES. Does nothing unless `SEND_EMAIIL` is True. +Create the initial admin user (only works when no users exist yet): -- `APPROVE_OWN` - By default, users with approval permissons can approve their own key requests. By setting this to False in settings.py (or by using the `APPROVE_OWN` environment variable with Docker), users cannot approve their own requests. +``` bash +./crypt-server -create-admin -username=admin -password='your-password' +``` -- `ALL_APPROVE` - By default, users need to be explicitly given approval permissions to approve key retrieval requests. By setting this to True in `settings.py`, all users are given this permission when they log in. +## Password reset -- `ROTATE_VIEWED_SECRETS` - With a compatible client (such as Crypt 3.2.0 and greater), Crypt Server can instruct the client to rotate the secret and re-escrow it when the secret has been viewed. Enable by setting this to `True` or by using `ROTATE_VIEWED_SECRETS` and setting to `true`. +Reset a user's password from the command line: -- `HOST_NAME` - Set the host name of your instance - required if you do not have control over the load balancer or proxy in front of your Crypt server (see [the Django documentation](https://docs.djangoproject.com/en/4.1/ref/settings/#csrf-trusted-origins)). +``` bash +./crypt-server -reset-password -username=admin -password='new-password' +``` -- `CSRF_TRUSTED_ORIGINS` - Is a list of trusted origins expected to make requests to your Crypt instance, normally this is the hostname ## Screenshots Main Page: diff --git a/VERSION b/VERSION new file mode 100644 index 0000000..bbf753a --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +4.0.0.1 diff --git a/cmd/crypt-server/admin.go b/cmd/crypt-server/admin.go new file mode 100644 index 0000000..aeace17 --- /dev/null +++ b/cmd/crypt-server/admin.go @@ -0,0 +1,56 @@ +package main + +import ( + "fmt" + "strings" + + "crypt-server/internal/app" + "crypt-server/internal/store" +) + +func createFirstAdmin(dataStore store.Store, username, password string) error { + if strings.TrimSpace(username) == "" { + return fmt.Errorf("admin username is required") + } + if password == "" { + return fmt.Errorf("admin password is required") + } + users, err := dataStore.ListUsers() + if err != nil { + return fmt.Errorf("list users: %w", err) + } + if len(users) > 0 { + return fmt.Errorf("cannot create first admin: users already exist") + } + hash, err := app.HashPassword(password) + if err != nil { + return fmt.Errorf("hash password: %w", err) + } + _, err = dataStore.AddUser(username, hash, true, true, true, false, "local") + if err != nil { + return fmt.Errorf("create admin: %w", err) + } + return nil +} + +func resetUserPassword(dataStore store.Store, username, password string) error { + if strings.TrimSpace(username) == "" { + return fmt.Errorf("username is required") + } + if password == "" { + return fmt.Errorf("password is required") + } + user, err := dataStore.GetUserByUsername(username) + if err != nil { + return fmt.Errorf("get user: %w", err) + } + hash, err := app.HashPassword(password) + if err != nil { + return fmt.Errorf("hash password: %w", err) + } + _, err = dataStore.UpdateUserPassword(user.ID, hash, false) + if err != nil { + return fmt.Errorf("update password: %w", err) + } + return nil +} diff --git a/cmd/crypt-server/admin_test.go b/cmd/crypt-server/admin_test.go new file mode 100644 index 0000000..fc566f0 --- /dev/null +++ b/cmd/crypt-server/admin_test.go @@ -0,0 +1,97 @@ +package main + +import ( + "encoding/base64" + "testing" + + "crypt-server/internal/crypto" + "crypt-server/internal/migrate" + "crypt-server/internal/store" + "github.com/stretchr/testify/require" +) + +func TestCreateFirstAdmin(t *testing.T) { + dataStore := newTestSQLiteStore(t) + err := createFirstAdmin(dataStore, "admin", "secret") + require.NoError(t, err) + + user, err := dataStore.GetUserByUsername("admin") + require.NoError(t, err) + require.True(t, user.IsStaff) + require.True(t, user.CanApprove) + require.True(t, user.LocalLoginEnabled) + require.False(t, user.MustResetPassword) + require.Equal(t, "local", user.AuthSource) +} + +func TestCreateFirstAdminRejectsExistingUsers(t *testing.T) { + dataStore := newTestSQLiteStore(t) + _, err := dataStore.AddUser("existing", "hash", true, true, true, false, "local") + require.NoError(t, err) + + err = createFirstAdmin(dataStore, "admin", "secret") + require.Error(t, err) +} + +func TestCreateFirstAdminRequiresUsername(t *testing.T) { + dataStore := newTestSQLiteStore(t) + err := createFirstAdmin(dataStore, " ", "secret") + require.Error(t, err) +} + +func TestCreateFirstAdminRequiresPassword(t *testing.T) { + dataStore := newTestSQLiteStore(t) + err := createFirstAdmin(dataStore, "admin", "") + require.Error(t, err) +} + +func TestResetUserPassword(t *testing.T) { + dataStore := newTestSQLiteStore(t) + _, err := dataStore.AddUser("testuser", "oldhash", false, false, true, false, "local") + require.NoError(t, err) + + err = resetUserPassword(dataStore, "testuser", "newpassword") + require.NoError(t, err) + + user, err := dataStore.GetUserByUsername("testuser") + require.NoError(t, err) + require.NotEqual(t, "oldhash", user.PasswordHash) +} + +func TestResetUserPasswordRequiresUsername(t *testing.T) { + dataStore := newTestSQLiteStore(t) + err := resetUserPassword(dataStore, " ", "newpassword") + require.Error(t, err) +} + +func TestResetUserPasswordRequiresPassword(t *testing.T) { + dataStore := newTestSQLiteStore(t) + _, err := dataStore.AddUser("testuser", "oldhash", false, false, true, false, "local") + require.NoError(t, err) + + err = resetUserPassword(dataStore, "testuser", "") + require.Error(t, err) +} + +func TestResetUserPasswordUserNotFound(t *testing.T) { + dataStore := newTestSQLiteStore(t) + err := resetUserPassword(dataStore, "nonexistent", "newpassword") + require.Error(t, err) +} + +func newTestSQLiteStore(t *testing.T) *store.SQLiteStore { + t.Helper() + key := make([]byte, 32) + for i := range key { + key[i] = byte(i + 1) + } + encoded := base64.StdEncoding.EncodeToString(key) + codec, err := crypto.NewAesGcmCodecFromBase64Key(encoded) + require.NoError(t, err) + sqliteStore, err := store.NewSQLiteStore(t.TempDir()+"/crypt.db", codec) + require.NoError(t, err) + migrationFS, err := migrate.SubMigrationsFS(migrate.EmbeddedFS, "sqlite") + require.NoError(t, err) + require.NoError(t, migrate.Apply(sqliteStore.DB(), "sqlite", migrationFS)) + return sqliteStore +} diff --git a/cmd/crypt-server/config.go b/cmd/crypt-server/config.go new file mode 100644 index 0000000..f306a8d --- /dev/null +++ b/cmd/crypt-server/config.go @@ -0,0 +1,39 @@ +package main + +import ( + "fmt" + "os" + "strings" +) + +type databaseConfig struct { + driver string + dsn string +} + +func loadDatabaseConfig() (databaseConfig, error) { + postgresURL := strings.TrimSpace(os.Getenv("DATABASE_URL")) + sqlitePath := strings.TrimSpace(os.Getenv("SQLITE_PATH")) + + if postgresURL != "" && sqlitePath != "" { + return databaseConfig{}, fmt.Errorf("set only one of DATABASE_URL or SQLITE_PATH") + } + if postgresURL != "" { + return databaseConfig{driver: "postgres", dsn: postgresURL}, nil + } + if sqlitePath != "" { + if isSQLiteMemory(sqlitePath) { + return databaseConfig{}, fmt.Errorf("SQLITE_PATH must point to a file, not an in-memory database") + } + return databaseConfig{driver: "sqlite", dsn: sqlitePath}, nil + } + return databaseConfig{}, fmt.Errorf("DATABASE_URL or SQLITE_PATH is required") +} + +func isSQLiteMemory(path string) bool { + cleaned := strings.ToLower(strings.TrimSpace(path)) + if cleaned == ":memory:" { + return true + } + return strings.Contains(cleaned, "mode=memory") +} diff --git a/cmd/crypt-server/config_test.go b/cmd/crypt-server/config_test.go new file mode 100644 index 0000000..ac56cbe --- /dev/null +++ b/cmd/crypt-server/config_test.go @@ -0,0 +1,59 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLoadDatabaseConfigPostgres(t *testing.T) { + t.Setenv("DATABASE_URL", "postgres://example") + t.Setenv("SQLITE_PATH", "") + + cfg, err := loadDatabaseConfig() + require.NoError(t, err) + require.Equal(t, "postgres", cfg.driver) + require.Equal(t, "postgres://example", cfg.dsn) +} + +func TestLoadDatabaseConfigSQLite(t *testing.T) { + t.Setenv("DATABASE_URL", "") + t.Setenv("SQLITE_PATH", "data/crypt.db") + + cfg, err := loadDatabaseConfig() + require.NoError(t, err) + require.Equal(t, "sqlite", cfg.driver) + require.Equal(t, "data/crypt.db", cfg.dsn) +} + +func TestLoadDatabaseConfigRejectsBoth(t *testing.T) { + t.Setenv("DATABASE_URL", "postgres://example") + t.Setenv("SQLITE_PATH", "data/crypt.db") + + _, err := loadDatabaseConfig() + require.Error(t, err) +} + +func TestLoadDatabaseConfigRejectsMissing(t *testing.T) { + t.Setenv("DATABASE_URL", "") + t.Setenv("SQLITE_PATH", "") + + _, err := loadDatabaseConfig() + require.Error(t, err) +} + +func TestLoadDatabaseConfigRejectsSQLiteMemory(t *testing.T) { + t.Setenv("DATABASE_URL", "") + t.Setenv("SQLITE_PATH", ":memory:") + + _, err := loadDatabaseConfig() + require.Error(t, err) +} + +func TestLoadDatabaseConfigRejectsSQLiteMemoryMode(t *testing.T) { + t.Setenv("DATABASE_URL", "") + t.Setenv("SQLITE_PATH", "file:crypt.db?mode=memory&cache=shared") + + _, err := loadDatabaseConfig() + require.Error(t, err) +} diff --git a/cmd/crypt-server/import.go b/cmd/crypt-server/import.go new file mode 100644 index 0000000..7fe4bb5 --- /dev/null +++ b/cmd/crypt-server/import.go @@ -0,0 +1,110 @@ +package main + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "time" + + "crypt-server/internal/fixture" + "crypt-server/internal/store" +) + +var ErrDatabaseNotEmpty = errors.New("database is not empty; import only allowed on empty database") + +func importFixture(st store.Store, fixturePath string) error { + // Check if database is empty first + isEmpty, err := st.IsEmpty() + if err != nil { + return fmt.Errorf("check database: %w", err) + } + if !isEmpty { + return ErrDatabaseNotEmpty + } + + // Read and parse the fixture file + data, err := os.ReadFile(fixturePath) + if err != nil { + return fmt.Errorf("read fixture file: %w", err) + } + + var migration fixture.MigrationOutput + if err := json.Unmarshal(data, &migration); err != nil { + return fmt.Errorf("parse fixture: %w", err) + } + + // Import computers first (secrets depend on them) + for _, c := range migration.Computers { + lastCheckin, err := parseDateTime(c.LastCheckin) + if err != nil { + return fmt.Errorf("parse computer %d last_checkin: %w", c.ID, err) + } + if err := st.ImportComputer(c.ID, c.Serial, c.Username, c.ComputerName, lastCheckin); err != nil { + return fmt.Errorf("import computer %d: %w", c.ID, err) + } + } + + // Import secrets (requests depend on them) + for _, s := range migration.Secrets { + dateEscrowed, err := parseDateTime(s.DateEscrowed) + if err != nil { + return fmt.Errorf("parse secret %d date_escrowed: %w", s.ID, err) + } + if err := st.ImportSecret(s.ID, s.ComputerID, s.SecretType, s.Secret, dateEscrowed, s.RotationRequired); err != nil { + return fmt.Errorf("import secret %d: %w", s.ID, err) + } + } + + // Import users + for _, u := range migration.Users { + if err := st.ImportUser(u.ID, u.Username, u.PasswordHash, u.IsStaff, u.CanApprove, u.LocalLoginEnabled, u.MustResetPassword, u.AuthSource); err != nil { + return fmt.Errorf("import user %d: %w", u.ID, err) + } + } + + // Import requests + for _, r := range migration.Requests { + dateRequested, err := parseDateTime(r.DateRequested) + if err != nil { + return fmt.Errorf("parse request %d date_requested: %w", r.ID, err) + } + var dateApproved *time.Time + if r.DateApproved != "" { + da, err := parseDateTime(r.DateApproved) + if err != nil { + return fmt.Errorf("parse request %d date_approved: %w", r.ID, err) + } + dateApproved = &da + } + if err := st.ImportRequest(r.ID, r.SecretID, r.RequestingUser, r.Approved, r.AuthUser, r.ReasonForRequest, r.ReasonForApproval, dateRequested, dateApproved, r.Current); err != nil { + return fmt.Errorf("import request %d: %w", r.ID, err) + } + } + + return nil +} + +// parseDateTime parses a datetime string from Django fixtures. +// Supports formats: "2006-01-02T15:04:05Z", "2006-01-02T15:04:05.000Z", "2006-01-02 15:04:05" +func parseDateTime(value string) (time.Time, error) { + if value == "" { + return time.Time{}, nil + } + + formats := []string{ + time.RFC3339, + "2006-01-02T15:04:05.000Z", + "2006-01-02T15:04:05Z", + "2006-01-02 15:04:05", + "2006-01-02T15:04:05", + } + + for _, format := range formats { + if t, err := time.Parse(format, value); err == nil { + return t, nil + } + } + + return time.Time{}, fmt.Errorf("unable to parse datetime: %s", value) +} diff --git a/cmd/crypt-server/import_test.go b/cmd/crypt-server/import_test.go new file mode 100644 index 0000000..a45f566 --- /dev/null +++ b/cmd/crypt-server/import_test.go @@ -0,0 +1,229 @@ +package main + +import ( + "os" + "path/filepath" + "testing" + "time" + + "crypt-server/internal/crypto" + "crypt-server/internal/migrate" + "crypt-server/internal/store" + + "github.com/stretchr/testify/require" +) + +func TestImportFixture(t *testing.T) { + codec, err := crypto.NewAesGcmCodecFromBase64Key("ija/CsKe9xs4RSia1SY/oVwMzMR2t5Fh3gd1GggbocY=") + require.NoError(t, err) + + t.Run("successful import on empty database", func(t *testing.T) { + st, cleanup := setupTestStore(t, codec) + defer cleanup() + + // Encrypt a test secret using the same codec + encryptedSecret, err := codec.Encrypt("test-recovery-key-12345") + require.NoError(t, err) + + // Create a test fixture file with properly encrypted secret + fixtureData := `{ + "computers": [ + {"id": 1, "serial": "ABC123", "username": "testuser", "computername": "Test Mac", "last_checkin": "2024-01-15T10:30:00Z"} + ], + "secrets": [ + {"id": 1, "computer_id": 1, "secret_type": "recovery_key", "secret": "` + encryptedSecret + `", "date_escrowed": "2024-01-15T10:30:00Z", "rotation_required": false} + ], + "users": [ + {"id": 1, "username": "admin", "email": "admin@example.com", "is_staff": true, "is_superuser": true, "can_approve": true, "groups": [], "password_hash": "", "must_reset_password": true, "local_login_enabled": false, "auth_source": "saml"} + ], + "requests": [ + {"id": 1, "secret_id": 1, "requesting_user": "admin", "approved": true, "auth_user": "admin", "reason_for_request": "Test", "reason_for_approval": "Approved", "date_requested": "2024-01-15T11:00:00Z", "date_approved": "2024-01-15T11:05:00Z", "current": true} + ] + }` + + fixturePath := filepath.Join(t.TempDir(), "fixture.json") + err = os.WriteFile(fixturePath, []byte(fixtureData), 0o600) + require.NoError(t, err) + + err = importFixture(st, fixturePath) + require.NoError(t, err) + + // Verify computer was imported + computers, err := st.ListComputers() + require.NoError(t, err) + require.Len(t, computers, 1) + require.Equal(t, "ABC123", computers[0].Serial) + require.Equal(t, "testuser", computers[0].Username) + require.Equal(t, "Test Mac", computers[0].ComputerName) + + // Verify user was imported + users, err := st.ListUsers() + require.NoError(t, err) + require.Len(t, users, 1) + require.Equal(t, "admin", users[0].Username) + require.True(t, users[0].IsStaff) + require.True(t, users[0].CanApprove) + + // Verify secret was imported and can be decrypted + secrets, err := st.ListSecretsByComputer(1) + require.NoError(t, err) + require.Len(t, secrets, 1) + require.Equal(t, "recovery_key", secrets[0].SecretType) + require.Equal(t, "test-recovery-key-12345", secrets[0].Secret) // Decrypted value + + // Verify request was imported + requests, err := st.ListRequestsBySecret(1) + require.NoError(t, err) + require.Len(t, requests, 1) + require.Equal(t, "admin", requests[0].RequestingUser) + require.NotNil(t, requests[0].Approved) + require.True(t, *requests[0].Approved) + }) + + t.Run("import fails on non-empty database", func(t *testing.T) { + st, cleanup := setupTestStore(t, codec) + defer cleanup() + + // Add existing data + _, err := st.AddUser("existing", "", true, false, false, false, "local") + require.NoError(t, err) + + fixtureData := `{"computers": [], "secrets": [], "users": [], "requests": []}` + fixturePath := filepath.Join(t.TempDir(), "fixture.json") + err = os.WriteFile(fixturePath, []byte(fixtureData), 0o600) + require.NoError(t, err) + + err = importFixture(st, fixturePath) + require.Error(t, err) + require.ErrorIs(t, err, ErrDatabaseNotEmpty) + }) + + t.Run("import fails with invalid JSON", func(t *testing.T) { + st, cleanup := setupTestStore(t, codec) + defer cleanup() + + fixturePath := filepath.Join(t.TempDir(), "fixture.json") + err := os.WriteFile(fixturePath, []byte("not valid json"), 0o600) + require.NoError(t, err) + + err = importFixture(st, fixturePath) + require.Error(t, err) + require.Contains(t, err.Error(), "parse fixture") + }) + + t.Run("import fails with missing file", func(t *testing.T) { + st, cleanup := setupTestStore(t, codec) + defer cleanup() + + err := importFixture(st, "/nonexistent/path/fixture.json") + require.Error(t, err) + require.Contains(t, err.Error(), "read fixture file") + }) +} + +func TestParseDateTimeFormats(t *testing.T) { + tests := []struct { + name string + input string + expected time.Time + wantErr bool + }{ + { + name: "RFC3339", + input: "2024-01-15T10:30:00Z", + expected: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), + }, + { + name: "RFC3339 with milliseconds", + input: "2024-01-15T10:30:00.000Z", + expected: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), + }, + { + name: "space-separated datetime", + input: "2024-01-15 10:30:00", + expected: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), + }, + { + name: "T-separated without Z", + input: "2024-01-15T10:30:00", + expected: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC), + }, + { + name: "empty string", + input: "", + expected: time.Time{}, + }, + { + name: "invalid format", + input: "not-a-date", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseDateTime(tt.input) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.expected, result) + } + }) + } +} + +func TestIsEmpty(t *testing.T) { + codec, err := crypto.NewAesGcmCodecFromBase64Key("ija/CsKe9xs4RSia1SY/oVwMzMR2t5Fh3gd1GggbocY=") + require.NoError(t, err) + + t.Run("returns true for empty database", func(t *testing.T) { + st, cleanup := setupTestStore(t, codec) + defer cleanup() + + isEmpty, err := st.IsEmpty() + require.NoError(t, err) + require.True(t, isEmpty) + }) + + t.Run("returns false when users exist", func(t *testing.T) { + st, cleanup := setupTestStore(t, codec) + defer cleanup() + + _, err := st.AddUser("testuser", "", false, false, false, false, "local") + require.NoError(t, err) + + isEmpty, err := st.IsEmpty() + require.NoError(t, err) + require.False(t, isEmpty) + }) + + t.Run("returns false when computers exist", func(t *testing.T) { + st, cleanup := setupTestStore(t, codec) + defer cleanup() + + _, err := st.AddComputer("SERIAL123", "user", "Computer") + require.NoError(t, err) + + isEmpty, err := st.IsEmpty() + require.NoError(t, err) + require.False(t, isEmpty) + }) +} + +func setupTestStore(t *testing.T, codec store.SecretCodec) (store.Store, func()) { + t.Helper() + + st, err := store.NewSQLiteStore(":memory:", codec) + require.NoError(t, err) + + migrationsFS, err := migrate.SubMigrationsFS(migrate.EmbeddedFS, "sqlite") + require.NoError(t, err) + + err = migrate.Apply(st.DB(), "sqlite", migrationsFS) + require.NoError(t, err) + + return st, func() { + st.DB().Close() + } +} diff --git a/cmd/crypt-server/main.go b/cmd/crypt-server/main.go new file mode 100644 index 0000000..f9f4868 --- /dev/null +++ b/cmd/crypt-server/main.go @@ -0,0 +1,164 @@ +package main + +import ( + "crypt-server/internal/app" + "crypt-server/internal/crypto" + "crypt-server/internal/migrate" + "crypt-server/internal/store" + "flag" + "log" + "net/http" + "os" + "strconv" + "time" + + "github.com/crewjam/saml/samlsp" +) + +func main() { + printMigrations := flag.Bool("print-migrations", false, "Print embedded migrations and exit") + validateMigrations := flag.Bool("validate-migrations", false, "Validate embedded migrations and exit") + migrationsDriver := flag.String("migrations-driver", "", "Migrations driver to target (postgres or sqlite)") + createAdmin := flag.Bool("create-admin", false, "Create the first admin user and exit") + resetPassword := flag.Bool("reset-password", false, "Reset a user's password and exit") + adminUsername := flag.String("username", "", "Username for admin operations") + adminPassword := flag.String("password", "", "Password for admin operations") + importFixturePath := flag.String("import-fixture", "", "Path to fixture JSON file to import (database must be empty)") + flag.Parse() + + logger := log.New(os.Stdout, "crypt-server ", log.LstdFlags) + if *printMigrations || *validateMigrations { + if err := runMigrationCommand(os.Stdout, migrate.EmbeddedFS, *migrationsDriver, *validateMigrations, *printMigrations); err != nil { + logger.Fatalf("migration command failed: %v", err) + } + return + } + + encryptionKey := os.Getenv("FIELD_ENCRYPTION_KEY") + codec, err := crypto.NewAesGcmCodecFromBase64Key(encryptionKey) + if err != nil { + logger.Fatalf("invalid encryption key: %v", err) + } + + dbConfig, err := loadDatabaseConfig() + if err != nil { + logger.Fatal(err) + } + var dataStore store.Store + switch dbConfig.driver { + case "postgres": + postgresStore, err := store.NewPostgresStore(dbConfig.dsn, codec) + if err != nil { + logger.Fatalf("database connection failed: %v", err) + } + pgFS, err := migrate.SubMigrationsFS(migrate.EmbeddedFS, "postgres") + if err != nil { + logger.Fatalf("database migration failed: %v", err) + } + if err := migrate.Apply(postgresStore.DB(), "postgres", pgFS); err != nil { + logger.Fatalf("database migration failed: %v", err) + } + dataStore = postgresStore + logger.Printf("using postgres store") + case "sqlite": + sqliteStore, err := store.NewSQLiteStore(dbConfig.dsn, codec) + if err != nil { + logger.Fatalf("database connection failed: %v", err) + } + sqliteFS, err := migrate.SubMigrationsFS(migrate.EmbeddedFS, "sqlite") + if err != nil { + logger.Fatalf("database migration failed: %v", err) + } + if err := migrate.Apply(sqliteStore.DB(), "sqlite", sqliteFS); err != nil { + logger.Fatalf("database migration failed: %v", err) + } + dataStore = sqliteStore + logger.Printf("using sqlite store") + default: + logger.Fatalf("unsupported database driver: %s", dbConfig.driver) + } + + // Wrap store with logging + dataStore = store.NewLoggingStore(dataStore, logger) + + if *createAdmin { + if err := createFirstAdmin(dataStore, *adminUsername, *adminPassword); err != nil { + logger.Fatalf("create admin failed: %v", err) + } + logger.Printf("created first admin user: %s", *adminUsername) + return + } + + if *resetPassword { + if err := resetUserPassword(dataStore, *adminUsername, *adminPassword); err != nil { + logger.Fatalf("reset password failed: %v", err) + } + logger.Printf("password reset for user: %s", *adminUsername) + return + } + + if *importFixturePath != "" { + logger.Printf("importing fixture from %s", *importFixturePath) + if err := importFixture(dataStore, *importFixturePath); err != nil { + logger.Fatalf("import fixture failed: %v", err) + } + logger.Printf("fixture imported successfully") + return + } + renderer := app.NewRenderer("web/templates/layouts/base.html", "web/templates/pages") + sessionKey := os.Getenv("SESSION_KEY") + if sessionKey == "" { + logger.Fatal("SESSION_KEY is required") + } + sessionTTL := 24 * time.Hour + sessionManager, err := app.NewSessionManager([]byte(sessionKey), "crypt_session", sessionTTL) + if err != nil { + logger.Fatalf("invalid session configuration: %v", err) + } + settings := app.Settings{ + ApproveOwn: envBool("APPROVE_OWN", false), + AllApprove: envBool("ALL_APPROVE", false), + SessionTTL: sessionTTL, + CookieSecure: envBool("SESSION_COOKIE_SECURE", false), + RequestCleanupInterval: time.Hour, + RotateViewedSecrets: envBool("ROTATE_VIEWED_SECRETS", false), + } + csrfManager := app.NewCSRFManager("crypt_csrf", 32) + + var samlSP *samlsp.Middleware + var samlConfig *app.SAMLConfig + samlConfigPath := os.Getenv("SAML_CONFIG_FILE") + if samlConfigPath != "" { + cfg, err := app.LoadSAMLConfig(samlConfigPath) + if err != nil { + logger.Fatalf("invalid saml config: %v", err) + } + samlProvider, err := app.BuildSAMLProvider(cfg) + if err != nil { + logger.Fatalf("saml setup failed: %v", err) + } + samlSP = samlProvider + samlConfig = cfg + logger.Printf("saml enabled") + } + + server := app.NewServer(dataStore, renderer, logger, sessionManager, csrfManager, samlSP, samlConfig, settings) + + addr := ":8080" + logger.Printf("listening on %s", addr) + if err := http.ListenAndServe(addr, server.Routes()); err != nil { + logger.Fatalf("server stopped: %v", err) + } +} + +func envBool(key string, fallback bool) bool { + raw := os.Getenv(key) + if raw == "" { + return fallback + } + parsed, err := strconv.ParseBool(raw) + if err != nil { + return fallback + } + return parsed +} diff --git a/cmd/crypt-server/migrations.go b/cmd/crypt-server/migrations.go new file mode 100644 index 0000000..f1dbe32 --- /dev/null +++ b/cmd/crypt-server/migrations.go @@ -0,0 +1,64 @@ +package main + +import ( + "fmt" + "io" + "io/fs" + "strings" + + "crypt-server/internal/migrate" +) + +func migrationDrivers(driver string) ([]string, error) { + switch strings.TrimSpace(driver) { + case "": + return []string{"postgres", "sqlite"}, nil + case "postgres": + return []string{"postgres"}, nil + case "sqlite": + return []string{"sqlite"}, nil + default: + return nil, fmt.Errorf("unsupported migrations driver: %s", driver) + } +} + +func runMigrationCommand(w io.Writer, fsys fs.FS, driver string, validate, print bool) error { + drivers, err := migrationDrivers(driver) + if err != nil { + return err + } + for _, dbDriver := range drivers { + sub, err := migrate.SubMigrationsFS(fsys, dbDriver) + if err != nil { + return fmt.Errorf("load %s migrations: %w", dbDriver, err) + } + if validate { + if err := migrate.Validate(sub); err != nil { + return fmt.Errorf("%s migrations invalid: %w", dbDriver, err) + } + } + if print { + if err := printMigrations(w, sub, dbDriver); err != nil { + return err + } + } + } + return nil +} + +func printMigrations(w io.Writer, fsys fs.FS, driver string) error { + migrations, err := migrate.List(fsys) + if err != nil { + return fmt.Errorf("list %s migrations: %w", driver, err) + } + if len(migrations) == 0 { + return fmt.Errorf("no %s migrations found", driver) + } + fmt.Fprintf(w, "== %s ==\n", driver) + for _, migration := range migrations { + fmt.Fprintf(w, "-- %03d %s\n", migration.Version, migration.Name) + fmt.Fprintf(w, "%s\n", strings.TrimRight(migration.SQL, "\n")) + fmt.Fprintln(w) + } + return nil +} diff --git a/cmd/crypt-server/migrations_test.go b/cmd/crypt-server/migrations_test.go new file mode 100644 index 0000000..24aefaa --- /dev/null +++ b/cmd/crypt-server/migrations_test.go @@ -0,0 +1,50 @@ +package main + +import ( + "bytes" + "testing" + "testing/fstest" + + "github.com/stretchr/testify/require" +) + +func TestMigrationDrivers(t *testing.T) { + drivers, err := migrationDrivers("") + require.NoError(t, err) + require.Equal(t, []string{"postgres", "sqlite"}, drivers) + + drivers, err = migrationDrivers("postgres") + require.NoError(t, err) + require.Equal(t, []string{"postgres"}, drivers) + + drivers, err = migrationDrivers("sqlite") + require.NoError(t, err) + require.Equal(t, []string{"sqlite"}, drivers) + + _, err = migrationDrivers("mysql") + require.Error(t, err) +} + +func TestRunMigrationCommandPrints(t *testing.T) { + fsys := fstest.MapFS{ + "migrations/postgres/001_init.sql": {Data: []byte("CREATE TABLE a (id INTEGER);")}, + "migrations/sqlite/001_init.sql": {Data: []byte("CREATE TABLE a (id INTEGER);")}, + } + var buf bytes.Buffer + + err := runMigrationCommand(&buf, fsys, "postgres", false, true) + require.NoError(t, err) + require.Contains(t, buf.String(), "== postgres ==") + require.Contains(t, buf.String(), "001_init.sql") +} + +func TestRunMigrationCommandValidates(t *testing.T) { + fsys := fstest.MapFS{ + "migrations/sqlite/001_init.sql": {Data: []byte("CREATE TABLE a (id INTEGER);")}, + } + + var buf bytes.Buffer + err := runMigrationCommand(&buf, fsys, "sqlite", true, false) + require.NoError(t, err) + require.Equal(t, "", buf.String()) +} diff --git a/cmd/cryptctl/fixture.go b/cmd/cryptctl/fixture.go new file mode 100644 index 0000000..51aa2fd --- /dev/null +++ b/cmd/cryptctl/fixture.go @@ -0,0 +1,306 @@ +package main + +import ( + "encoding/json" + "errors" + "fmt" + "sort" + "strings" + "time" + + "crypt-server/internal/crypto" + "github.com/fernet/fernet-go" +) + +func parseFixture(data []byte) ([]fixtureEntry, error) { + var entries []fixtureEntry + if err := json.Unmarshal(data, &entries); err != nil { + return nil, err + } + return entries, nil +} + +func convertFixture(entries []fixtureEntry, legacyKey *fernet.Key, newCodec *crypto.AesGcmCodec, passwordMap map[string]passwordMapEntry) (*migrationOutput, error) { + users := make(map[int]userOut) + usernames := make(map[int]string) + emails := make(map[int]string) + groups := make(map[int]string) + permissionIDs := make(map[int]struct{}) + userPermissions := make(map[int][]int) + groupPermissions := make(map[int][]int) + userGroups := make(map[int][]int) + computers := make([]computerOut, 0) + secrets := make([]secretOut, 0) + requests := make([]requestOut, 0) + + for _, entry := range entries { + pk := entry.pkInt() + switch entry.Model { + case "auth.user": + username := getString(entry.Fields, "username") + user := userOut{ + ID: pk, + Username: username, + Email: getString(entry.Fields, "email"), + IsStaff: getBool(entry.Fields, "is_staff"), + IsSuper: getBool(entry.Fields, "is_superuser"), + Groups: []string{}, + } + users[pk] = user + usernames[pk] = username + emails[pk] = strings.ToLower(getString(entry.Fields, "email")) + case "auth.group": + groups[pk] = getString(entry.Fields, "name") + case "auth.permission": + if getString(entry.Fields, "codename") == "can_approve" { + permissionIDs[pk] = struct{}{} + } + case "auth.user_user_permissions": + userID := getInt(entry.Fields, "user") + permID := getInt(entry.Fields, "permission") + userPermissions[userID] = append(userPermissions[userID], permID) + case "auth.group_permissions": + groupID := getInt(entry.Fields, "group") + permID := getInt(entry.Fields, "permission") + groupPermissions[groupID] = append(groupPermissions[groupID], permID) + case "auth.user_groups": + userID := getInt(entry.Fields, "user") + groupID := getInt(entry.Fields, "group") + userGroups[userID] = append(userGroups[userID], groupID) + } + } + + for _, entry := range entries { + pk := entry.pkInt() + switch entry.Model { + case "server.computer": + computers = append(computers, computerOut{ + ID: pk, + Serial: getString(entry.Fields, "serial"), + Username: getString(entry.Fields, "username"), + ComputerName: getString(entry.Fields, "computername"), + LastCheckin: getString(entry.Fields, "last_checkin"), + }) + case "server.secret": + ciphertext := getString(entry.Fields, "secret") + plaintext, err := decryptLegacySecret(ciphertext, legacyKey) + if err != nil { + return nil, fmt.Errorf("decrypt secret %d: %w", pk, err) + } + encrypted, err := newCodec.Encrypt(plaintext) + if err != nil { + return nil, fmt.Errorf("encrypt secret %d: %w", pk, err) + } + secrets = append(secrets, secretOut{ + ID: pk, + ComputerID: getInt(entry.Fields, "computer"), + SecretType: getString(entry.Fields, "secret_type"), + Secret: encrypted, + DateEscrowed: getString(entry.Fields, "date_escrowed"), + RotationRequired: getBool(entry.Fields, "rotation_required"), + }) + case "server.request": + requestingUser := usernameForID(usernames, getOptionalInt(entry.Fields, "requesting_user")) + authUser := usernameForID(usernames, getOptionalInt(entry.Fields, "auth_user")) + requests = append(requests, requestOut{ + ID: pk, + SecretID: getInt(entry.Fields, "secret"), + RequestingUser: requestingUser, + Approved: getOptionalBool(entry.Fields, "approved"), + AuthUser: authUser, + ReasonForRequest: getString(entry.Fields, "reason_for_request"), + ReasonForApproval: getString(entry.Fields, "reason_for_approval"), + DateRequested: getString(entry.Fields, "date_requested"), + DateApproved: getString(entry.Fields, "date_approved"), + Current: getBool(entry.Fields, "current"), + }) + } + } + + userList := make([]userOut, 0, len(users)) + for id, user := range users { + groupNames := mapGroups(userGroups[id], groups) + user.Groups = groupNames + user.CanApprove = resolveCanApprove(user, userPermissions[id], groupPermissions, userGroups[id], permissionIDs) + applyPasswordMapping(&user, passwordMap, emails[id]) + userList = append(userList, user) + } + sort.Slice(userList, func(i, j int) bool { + return userList[i].ID < userList[j].ID + }) + + return &migrationOutput{ + Computers: computers, + Secrets: secrets, + Requests: requests, + Users: userList, + }, nil +} + +func mapGroups(groupIDs []int, groups map[int]string) []string { + unique := map[string]struct{}{} + for _, groupID := range groupIDs { + if name := groups[groupID]; name != "" { + unique[name] = struct{}{} + } + } + names := make([]string, 0, len(unique)) + for name := range unique { + names = append(names, name) + } + sort.Strings(names) + return names +} + +func resolveCanApprove(user userOut, userPerms []int, groupPerms map[int][]int, userGroupIDs []int, canApproveIDs map[int]struct{}) bool { + if user.IsSuper { + return true + } + if hasPermission(userPerms, canApproveIDs) { + return true + } + for _, groupID := range userGroupIDs { + if hasPermission(groupPerms[groupID], canApproveIDs) { + return true + } + } + return false +} + +func hasPermission(permissionIDs []int, allowed map[int]struct{}) bool { + for _, permID := range permissionIDs { + if _, ok := allowed[permID]; ok { + return true + } + } + return false +} + +func applyPasswordMapping(user *userOut, passwordMap map[string]passwordMapEntry, email string) { + user.MustResetPassword = true + user.LocalLoginEnabled = false + user.AuthSource = "saml" + if entry, ok := passwordMap[strings.ToLower(user.Username)]; ok { + user.PasswordHash = entry.PasswordHash + user.MustResetPassword = entry.MustResetPassword + user.LocalLoginEnabled = true + user.AuthSource = "local" + return + } + if email != "" { + if entry, ok := passwordMap[email]; ok { + user.PasswordHash = entry.PasswordHash + user.MustResetPassword = entry.MustResetPassword + user.LocalLoginEnabled = true + user.AuthSource = "local" + } + } +} + +func decryptLegacySecret(value string, key *fernet.Key) (string, error) { + if value == "" { + return "", errors.New("empty secret") + } + // Try Fernet decryption first (for encrypted Django databases) + plaintext := fernet.VerifyAndDecrypt([]byte(value), 0*time.Second, []*fernet.Key{key}) + if plaintext != nil { + return string(plaintext), nil + } + // Fallback: if Fernet decryption fails, assume the secret is plaintext + // (some Django installations may not have field encryption enabled) + return value, nil +} + +func marshalOutput(output *migrationOutput) ([]byte, error) { + return json.MarshalIndent(output, "", " ") +} + +func usernameForID(users map[int]string, id *int) string { + if id == nil { + return "" + } + if username, ok := users[*id]; ok { + return username + } + return fmt.Sprintf("user-%d", *id) +} + +func getString(fields map[string]interface{}, key string) string { + value, ok := fields[key] + if !ok || value == nil { + return "" + } + switch typed := value.(type) { + case string: + return typed + default: + return fmt.Sprintf("%v", typed) + } +} + +func getInt(fields map[string]interface{}, key string) int { + value, ok := fields[key] + if !ok || value == nil { + return 0 + } + switch typed := value.(type) { + case float64: + return int(typed) + case int: + return typed + case json.Number: + parsed, _ := typed.Int64() + return int(parsed) + default: + return 0 + } +} + +func getOptionalInt(fields map[string]interface{}, key string) *int { + value, ok := fields[key] + if !ok || value == nil { + return nil + } + switch typed := value.(type) { + case float64: + value := int(typed) + return &value + case int: + return &typed + case json.Number: + parsed, _ := typed.Int64() + value := int(parsed) + return &value + default: + return nil + } +} + +func getBool(fields map[string]interface{}, key string) bool { + value, ok := fields[key] + if !ok || value == nil { + return false + } + switch typed := value.(type) { + case bool: + return typed + default: + return false + } +} + +func getOptionalBool(fields map[string]interface{}, key string) *bool { + value, ok := fields[key] + if !ok { + return nil + } + if value == nil { + return nil + } + switch typed := value.(type) { + case bool: + return &typed + default: + return nil + } +} diff --git a/cmd/cryptctl/fixture_error_test.go b/cmd/cryptctl/fixture_error_test.go new file mode 100644 index 0000000..2b298ad --- /dev/null +++ b/cmd/cryptctl/fixture_error_test.go @@ -0,0 +1,79 @@ +package main + +import ( + "encoding/base64" + "encoding/json" + "os" + "path/filepath" + "testing" + + "crypt-server/internal/crypto" + "github.com/fernet/fernet-go" + "github.com/stretchr/testify/require" +) + +func TestConvertFixturePlaintextSecretFallback(t *testing.T) { + // Test that non-Fernet secrets are treated as plaintext (fallback behavior) + // This supports Django databases that didn't have field encryption enabled + legacyKey := fernet.Key{} + require.NoError(t, legacyKey.Generate()) + legacyDecoded := fernet.MustDecodeKeys(string(legacyKey.Encode())) + + codec, err := crypto.NewAesGcmCodecFromBase64Key(validKeyBase64()) + require.NoError(t, err) + + entries := []fixtureEntry{ + {Model: "server.secret", PK: 20, Fields: map[string]interface{}{"computer": 10, "secret_type": "recovery_key", "secret": "plaintext-recovery-key"}}, + } + + output, err := convertFixture(entries, legacyDecoded[0], codec, map[string]passwordMapEntry{}) + require.NoError(t, err) + require.Len(t, output.Secrets, 1) + // Verify the plaintext was encrypted with the new codec + decrypted, err := codec.Decrypt(output.Secrets[0].Secret) + require.NoError(t, err) + require.Equal(t, "plaintext-recovery-key", decrypted) +} + +func TestConvertFixtureMissingSecret(t *testing.T) { + legacyKey := fernet.Key{} + require.NoError(t, legacyKey.Generate()) + legacyDecoded := fernet.MustDecodeKeys(string(legacyKey.Encode())) + + codec, err := crypto.NewAesGcmCodecFromBase64Key(validKeyBase64()) + require.NoError(t, err) + + entries := []fixtureEntry{ + {Model: "server.secret", PK: 21, Fields: map[string]interface{}{"computer": 10, "secret_type": "recovery_key"}}, + } + + _, err = convertFixture(entries, legacyDecoded[0], codec, map[string]passwordMapEntry{}) + require.Error(t, err) +} + +func TestRunImportFixtureMissingArgs(t *testing.T) { + err := runImportFixture([]string{"--input", "", "--output", ""}) + require.Error(t, err) +} + +func TestRunImportFixtureMissingKeys(t *testing.T) { + tmp := t.TempDir() + entries := []fixtureEntry{{Model: "server.computer", PK: 10, Fields: map[string]interface{}{"serial": "SERIAL"}}} + payload, err := json.Marshal(entries) + require.NoError(t, err) + + inputPath := filepath.Join(tmp, "fixture.json") + outputPath := filepath.Join(tmp, "output.json") + require.NoError(t, os.WriteFile(inputPath, payload, 0o600)) + + err = runImportFixture([]string{"--input", inputPath, "--output", outputPath}) + require.Error(t, err) +} + +func validKeyBase64() string { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i + 1) + } + return base64.StdEncoding.EncodeToString(key) +} diff --git a/cmd/cryptctl/fixture_helpers_test.go b/cmd/cryptctl/fixture_helpers_test.go new file mode 100644 index 0000000..d384cc6 --- /dev/null +++ b/cmd/cryptctl/fixture_helpers_test.go @@ -0,0 +1,294 @@ +package main + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetString(t *testing.T) { + tests := []struct { + name string + fields map[string]interface{} + key string + expected string + }{ + {"string value", map[string]interface{}{"foo": "bar"}, "foo", "bar"}, + {"missing key", map[string]interface{}{}, "foo", ""}, + {"nil value", map[string]interface{}{"foo": nil}, "foo", ""}, + {"int value", map[string]interface{}{"foo": 42}, "foo", "42"}, + {"float value", map[string]interface{}{"foo": 3.14}, "foo", "3.14"}, + {"bool value", map[string]interface{}{"foo": true}, "foo", "true"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, getString(tt.fields, tt.key)) + }) + } +} + +func TestGetInt(t *testing.T) { + jsonNum := json.Number("42") + tests := []struct { + name string + fields map[string]interface{} + key string + expected int + }{ + {"float64 value", map[string]interface{}{"foo": float64(42)}, "foo", 42}, + {"int value", map[string]interface{}{"foo": 42}, "foo", 42}, + {"json.Number value", map[string]interface{}{"foo": jsonNum}, "foo", 42}, + {"missing key", map[string]interface{}{}, "foo", 0}, + {"nil value", map[string]interface{}{"foo": nil}, "foo", 0}, + {"string value", map[string]interface{}{"foo": "bar"}, "foo", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, getInt(tt.fields, tt.key)) + }) + } +} + +func TestGetOptionalInt(t *testing.T) { + jsonNum := json.Number("42") + t.Run("float64 value", func(t *testing.T) { + result := getOptionalInt(map[string]interface{}{"foo": float64(42)}, "foo") + require.NotNil(t, result) + require.Equal(t, 42, *result) + }) + + t.Run("int value", func(t *testing.T) { + result := getOptionalInt(map[string]interface{}{"foo": 42}, "foo") + require.NotNil(t, result) + require.Equal(t, 42, *result) + }) + + t.Run("json.Number value", func(t *testing.T) { + result := getOptionalInt(map[string]interface{}{"foo": jsonNum}, "foo") + require.NotNil(t, result) + require.Equal(t, 42, *result) + }) + + t.Run("missing key", func(t *testing.T) { + result := getOptionalInt(map[string]interface{}{}, "foo") + require.Nil(t, result) + }) + + t.Run("nil value", func(t *testing.T) { + result := getOptionalInt(map[string]interface{}{"foo": nil}, "foo") + require.Nil(t, result) + }) + + t.Run("string value", func(t *testing.T) { + result := getOptionalInt(map[string]interface{}{"foo": "bar"}, "foo") + require.Nil(t, result) + }) +} + +func TestGetBool(t *testing.T) { + tests := []struct { + name string + fields map[string]interface{} + key string + expected bool + }{ + {"true value", map[string]interface{}{"foo": true}, "foo", true}, + {"false value", map[string]interface{}{"foo": false}, "foo", false}, + {"missing key", map[string]interface{}{}, "foo", false}, + {"nil value", map[string]interface{}{"foo": nil}, "foo", false}, + {"string value", map[string]interface{}{"foo": "true"}, "foo", false}, + {"int value", map[string]interface{}{"foo": 1}, "foo", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, getBool(tt.fields, tt.key)) + }) + } +} + +func TestGetOptionalBool(t *testing.T) { + t.Run("true value", func(t *testing.T) { + result := getOptionalBool(map[string]interface{}{"foo": true}, "foo") + require.NotNil(t, result) + require.True(t, *result) + }) + + t.Run("false value", func(t *testing.T) { + result := getOptionalBool(map[string]interface{}{"foo": false}, "foo") + require.NotNil(t, result) + require.False(t, *result) + }) + + t.Run("missing key", func(t *testing.T) { + result := getOptionalBool(map[string]interface{}{}, "foo") + require.Nil(t, result) + }) + + t.Run("nil value", func(t *testing.T) { + result := getOptionalBool(map[string]interface{}{"foo": nil}, "foo") + require.Nil(t, result) + }) + + t.Run("string value", func(t *testing.T) { + result := getOptionalBool(map[string]interface{}{"foo": "true"}, "foo") + require.Nil(t, result) + }) +} + +func TestUsernameForID(t *testing.T) { + users := map[int]string{ + 1: "admin", + 2: "user", + } + + t.Run("existing user", func(t *testing.T) { + id := 1 + require.Equal(t, "admin", usernameForID(users, &id)) + }) + + t.Run("nil id", func(t *testing.T) { + require.Equal(t, "", usernameForID(users, nil)) + }) + + t.Run("missing user", func(t *testing.T) { + id := 999 + require.Equal(t, "user-999", usernameForID(users, &id)) + }) +} + +func TestMapGroups(t *testing.T) { + groups := map[int]string{ + 1: "admins", + 2: "approvers", + 3: "users", + } + + t.Run("multiple groups", func(t *testing.T) { + result := mapGroups([]int{1, 2}, groups) + require.Equal(t, []string{"admins", "approvers"}, result) + }) + + t.Run("empty group ids", func(t *testing.T) { + result := mapGroups([]int{}, groups) + require.Empty(t, result) + }) + + t.Run("missing group", func(t *testing.T) { + result := mapGroups([]int{1, 999}, groups) + require.Equal(t, []string{"admins"}, result) + }) + + t.Run("duplicate groups", func(t *testing.T) { + result := mapGroups([]int{1, 1, 2}, groups) + require.Equal(t, []string{"admins", "approvers"}, result) + }) +} + +func TestHasPermission(t *testing.T) { + allowed := map[int]struct{}{ + 10: {}, + 20: {}, + } + + t.Run("has permission", func(t *testing.T) { + require.True(t, hasPermission([]int{10}, allowed)) + require.True(t, hasPermission([]int{5, 10, 15}, allowed)) + }) + + t.Run("no permission", func(t *testing.T) { + require.False(t, hasPermission([]int{5, 15}, allowed)) + require.False(t, hasPermission([]int{}, allowed)) + }) +} + +func TestResolveCanApprove(t *testing.T) { + canApproveIDs := map[int]struct{}{100: {}} + + t.Run("superuser can approve", func(t *testing.T) { + user := userOut{IsSuper: true} + require.True(t, resolveCanApprove(user, nil, nil, nil, canApproveIDs)) + }) + + t.Run("user with direct permission", func(t *testing.T) { + user := userOut{} + require.True(t, resolveCanApprove(user, []int{100}, nil, nil, canApproveIDs)) + }) + + t.Run("user with group permission", func(t *testing.T) { + user := userOut{} + groupPerms := map[int][]int{ + 5: {100}, + } + require.True(t, resolveCanApprove(user, nil, groupPerms, []int{5}, canApproveIDs)) + }) + + t.Run("user without permission", func(t *testing.T) { + user := userOut{} + require.False(t, resolveCanApprove(user, []int{50}, nil, nil, canApproveIDs)) + }) +} + +func TestApplyPasswordMapping(t *testing.T) { + passwordMap := map[string]passwordMapEntry{ + "admin": {PasswordHash: "hash1", MustResetPassword: false}, + "user@example.com": {PasswordHash: "hash2", MustResetPassword: true}, + } + + t.Run("match by username", func(t *testing.T) { + user := &userOut{Username: "admin"} + applyPasswordMapping(user, passwordMap, "") + require.Equal(t, "hash1", user.PasswordHash) + require.False(t, user.MustResetPassword) + require.True(t, user.LocalLoginEnabled) + require.Equal(t, "local", user.AuthSource) + }) + + t.Run("match by email", func(t *testing.T) { + user := &userOut{Username: "someone"} + applyPasswordMapping(user, passwordMap, "user@example.com") + require.Equal(t, "hash2", user.PasswordHash) + require.True(t, user.MustResetPassword) + require.True(t, user.LocalLoginEnabled) + require.Equal(t, "local", user.AuthSource) + }) + + t.Run("no match defaults to saml", func(t *testing.T) { + user := &userOut{Username: "unknown"} + applyPasswordMapping(user, passwordMap, "unknown@example.com") + require.Empty(t, user.PasswordHash) + require.True(t, user.MustResetPassword) + require.False(t, user.LocalLoginEnabled) + require.Equal(t, "saml", user.AuthSource) + }) + + t.Run("case insensitive username match", func(t *testing.T) { + user := &userOut{Username: "Admin"} + applyPasswordMapping(user, passwordMap, "") + // Note: the current implementation lowercases username for lookup + require.Equal(t, "hash1", user.PasswordHash) + }) +} + +func TestMarshalOutput(t *testing.T) { + output := &migrationOutput{ + Computers: []computerOut{{ID: 1, Serial: "ABC"}}, + Secrets: []secretOut{}, + Requests: []requestOut{}, + Users: []userOut{}, + } + + data, err := marshalOutput(output) + require.NoError(t, err) + require.Contains(t, string(data), `"serial": "ABC"`) + require.Contains(t, string(data), `"id": 1`) +} + +func TestDecryptLegacySecretEmpty(t *testing.T) { + _, err := decryptLegacySecret("", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "empty secret") +} diff --git a/cmd/cryptctl/fixture_test.go b/cmd/cryptctl/fixture_test.go new file mode 100644 index 0000000..a7b88ac --- /dev/null +++ b/cmd/cryptctl/fixture_test.go @@ -0,0 +1,135 @@ +package main + +import ( + "encoding/base64" + "encoding/json" + "os" + "path/filepath" + "testing" + + "crypt-server/internal/crypto" + "github.com/fernet/fernet-go" + "github.com/stretchr/testify/require" +) + +func TestParseFixture(t *testing.T) { + entries := []fixtureEntry{{Model: "server.computer", PK: 1, Fields: map[string]interface{}{"serial": "ABC"}}} + data, err := json.Marshal(entries) + require.NoError(t, err) + + parsed, err := parseFixture(data) + require.NoError(t, err) + require.Len(t, parsed, 1) + require.Equal(t, "server.computer", parsed[0].Model) +} + +func TestConvertFixture(t *testing.T) { + legacyKey := fernet.Key{} + require.NoError(t, legacyKey.Generate()) + legacyKeyEncoded := legacyKey.Encode() + legacyDecoded := fernet.MustDecodeKeys(string(legacyKeyEncoded)) + + newCodec := testCodec(t) + + ciphertext, err := fernet.EncryptAndSign([]byte("secret"), &legacyKey) + require.NoError(t, err) + entries := []fixtureEntry{ + {Model: "auth.user", PK: 1, Fields: map[string]interface{}{"username": "admin", "email": "admin@example.com", "is_staff": true, "is_superuser": true}}, + {Model: "auth.group", PK: 2, Fields: map[string]interface{}{"name": "approvers"}}, + {Model: "auth.permission", PK: 3, Fields: map[string]interface{}{"codename": "can_approve"}}, + {Model: "auth.group_permissions", PK: 4, Fields: map[string]interface{}{"group": 2, "permission": 3}}, + {Model: "auth.user_groups", PK: 5, Fields: map[string]interface{}{"user": 1, "group": 2}}, + {Model: "server.computer", PK: 10, Fields: map[string]interface{}{"serial": "SERIAL", "username": "user", "computername": "Mac", "last_checkin": "2024-01-01T00:00:00Z"}}, + {Model: "server.secret", PK: 20, Fields: map[string]interface{}{"computer": 10, "secret_type": "recovery_key", "secret": string(ciphertext), "date_escrowed": "2024-01-01T00:00:00Z", "rotation_required": false}}, + {Model: "server.request", PK: 30, Fields: map[string]interface{}{"secret": 20, "requesting_user": 1, "approved": true, "auth_user": 1, "reason_for_request": "Need access", "reason_for_approval": "ok", "date_requested": "2024-01-01T00:00:00Z", "date_approved": "2024-01-01T01:00:00Z", "current": true}}, + {Model: "server.request", PK: 31, Fields: map[string]interface{}{"secret": 20, "requesting_user": 1, "approved": false, "auth_user": 1, "reason_for_request": "Need access again", "reason_for_approval": "denied", "date_requested": "2024-01-02T00:00:00Z", "date_approved": "2024-01-02T01:00:00Z", "current": false}}, + } + + passwordHash, err := hashPasswordForExport("Str0ng!Passw0rd") + require.NoError(t, err) + passwordMap := map[string]passwordMapEntry{ + "admin": {PasswordHash: passwordHash, MustResetPassword: false}, + } + + output, err := convertFixture(entries, legacyDecoded[0], newCodec, passwordMap) + require.NoError(t, err) + require.Len(t, output.Computers, 1) + require.Len(t, output.Secrets, 1) + require.Len(t, output.Requests, 2) + require.Len(t, output.Users, 1) + + decrypted, err := newCodec.Decrypt(output.Secrets[0].Secret) + require.NoError(t, err) + require.Equal(t, "secret", decrypted) + require.Equal(t, "admin", output.Requests[0].RequestingUser) + require.Equal(t, "admin", output.Requests[0].AuthUser) + require.Equal(t, false, *output.Requests[1].Approved) + require.Equal(t, "denied", output.Requests[1].ReasonForApproval) + require.True(t, output.Users[0].CanApprove) + require.Equal(t, []string{"approvers"}, output.Users[0].Groups) + require.True(t, output.Users[0].LocalLoginEnabled) + require.Equal(t, "local", output.Users[0].AuthSource) +} + +func TestRunImportFixture(t *testing.T) { + tmp := t.TempDir() + legacyKey := fernet.Key{} + require.NoError(t, legacyKey.Generate()) + legacyKeyEncoded := legacyKey.Encode() + + newKey := make([]byte, 32) + for i := range newKey { + newKey[i] = byte(i + 1) + } + newKeyEncoded := base64.StdEncoding.EncodeToString(newKey) + + ciphertext, err := fernet.EncryptAndSign([]byte("secret"), &legacyKey) + require.NoError(t, err) + entries := []fixtureEntry{ + {Model: "server.computer", PK: 10, Fields: map[string]interface{}{"serial": "SERIAL", "username": "user", "computername": "Mac"}}, + {Model: "server.secret", PK: 20, Fields: map[string]interface{}{"computer": 10, "secret_type": "recovery_key", "secret": string(ciphertext)}}, + } + payload, err := json.Marshal(entries) + require.NoError(t, err) + + inputPath := filepath.Join(tmp, "fixture.json") + outputPath := filepath.Join(tmp, "output.json") + legacyKeyPath := filepath.Join(tmp, "legacy.key") + newKeyPath := filepath.Join(tmp, "new.key") + + require.NoError(t, os.WriteFile(inputPath, payload, 0o600)) + require.NoError(t, os.WriteFile(legacyKeyPath, []byte(legacyKeyEncoded), 0o600)) + require.NoError(t, os.WriteFile(newKeyPath, []byte(newKeyEncoded), 0o600)) + + err = runImportFixture([]string{ + "--input", inputPath, + "--output", outputPath, + "--legacy-key-file", legacyKeyPath, + "--new-key-file", newKeyPath, + }) + require.NoError(t, err) + + outputBytes, err := os.ReadFile(outputPath) + require.NoError(t, err) + var output migrationOutput + require.NoError(t, json.Unmarshal(outputBytes, &output)) + require.Len(t, output.Secrets, 1) + + codec, err := crypto.NewAesGcmCodecFromBase64Key(newKeyEncoded) + require.NoError(t, err) + plaintext, err := codec.Decrypt(output.Secrets[0].Secret) + require.NoError(t, err) + require.Equal(t, "secret", plaintext) +} + +func testCodec(t *testing.T) *crypto.AesGcmCodec { + t.Helper() + key := make([]byte, 32) + for i := range key { + key[i] = byte(i + 1) + } + encoded := base64.StdEncoding.EncodeToString(key) + codec, err := crypto.NewAesGcmCodecFromBase64Key(encoded) + require.NoError(t, err) + return codec +} diff --git a/cmd/cryptctl/integration.go b/cmd/cryptctl/integration.go new file mode 100644 index 0000000..3367b54 --- /dev/null +++ b/cmd/cryptctl/integration.go @@ -0,0 +1,240 @@ +package main + +import ( + "database/sql" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "time" + + "crypt-server/internal/app" + "crypt-server/internal/crypto" + "crypt-server/internal/migrate" + "crypt-server/internal/store" + + _ "github.com/lib/pq" + _ "modernc.org/sqlite" +) + +func runIntegrationTest(args []string, stdout io.Writer) error { + fs := flag.NewFlagSet("integration-test", flag.ExitOnError) + dbType := fs.String("db", "sqlite", "Database type: sqlite or postgres") + dbURL := fs.String("db-url", "", "Database URL (for postgres) or file path (for sqlite)") + encryptionKey := fs.String("key", "", "Base64 FIELD_ENCRYPTION_KEY") + encryptionKeyFile := fs.String("key-file", "", "Path to file containing FIELD_ENCRYPTION_KEY") + fs.Parse(args) + + keyValue, err := loadKey(*encryptionKey, *encryptionKeyFile, "FIELD_ENCRYPTION_KEY") + if err != nil { + return fmt.Errorf("load encryption key: %w", err) + } + + codec, err := crypto.NewAesGcmCodecFromBase64Key(keyValue) + if err != nil { + return fmt.Errorf("create codec: %w", err) + } + + var st store.Store + var db *sql.DB + + switch *dbType { + case "sqlite": + dsn := *dbURL + if dsn == "" { + dsn = ":memory:" + } + sqliteStore, err := store.NewSQLiteStore(dsn, codec) + if err != nil { + return fmt.Errorf("open sqlite: %w", err) + } + st = sqliteStore + db = sqliteStore.DB() + case "postgres": + if *dbURL == "" { + return errors.New("db-url is required for postgres") + } + pgStore, err := store.NewPostgresStore(*dbURL, codec) + if err != nil { + return fmt.Errorf("open postgres: %w", err) + } + st = pgStore + db = pgStore.DB() + default: + return fmt.Errorf("unsupported database type: %s", *dbType) + } + defer db.Close() + + // Run migrations using the embedded migration files + fmt.Fprintln(stdout, "Running migrations...") + migrationsFS, err := migrate.SubMigrationsFS(migrate.EmbeddedFS, *dbType) + if err != nil { + return fmt.Errorf("load migrations: %w", err) + } + if err := migrate.Apply(db, *dbType, migrationsFS); err != nil { + return fmt.Errorf("run migrations: %w", err) + } + + // Create test server with minimal configuration + fmt.Fprintln(stdout, "Creating test server...") + logger := log.New(io.Discard, "", 0) + renderer := app.NewRenderer("web/templates/base.html", "web/templates") + sessionKey := make([]byte, 32) + sessionManager, err := app.NewSessionManager(sessionKey, "crypt_session", 24*time.Hour) + if err != nil { + return fmt.Errorf("create session manager: %w", err) + } + csrfManager := app.NewCSRFManager("csrf_token", 32) + settings := app.Settings{} + + server := app.NewServer(st, renderer, logger, sessionManager, csrfManager, nil, nil, settings) + handler := server.Routes() + + testSerial := "TEST-SERIAL-001" + testUsername := "testuser" + testMacName := "Test Mac" + testSecret := "test-recovery-key-12345" + testSecretType := "recovery_key" + + // Test 1: Send initial checkin + fmt.Fprintln(stdout, "\n=== Test 1: Send initial checkin ===") + if err := testCheckin(handler, stdout, testSerial, testUsername, testMacName, testSecret, testSecretType); err != nil { + return fmt.Errorf("test 1 (initial checkin): %w", err) + } + + // Test 2: Verify secret is stored and encrypted correctly + fmt.Fprintln(stdout, "\n=== Test 2: Verify secret retrieval ===") + secretCount, err := testSecretRetrieval(st, stdout, testSerial, testSecretType, testSecret) + if err != nil { + return fmt.Errorf("test 2 (secret retrieval): %w", err) + } + if secretCount != 1 { + return fmt.Errorf("expected 1 secret, got %d", secretCount) + } + + // Test 3: Send duplicate checkin (should NOT create new secret) + fmt.Fprintln(stdout, "\n=== Test 3: Send duplicate checkin ===") + if err := testCheckin(handler, stdout, testSerial, testUsername, testMacName, testSecret, testSecretType); err != nil { + return fmt.Errorf("test 3 (duplicate checkin): %w", err) + } + + // Test 4: Verify no duplicate was created + fmt.Fprintln(stdout, "\n=== Test 4: Verify no duplicate ===") + secretCount, err = testSecretRetrieval(st, stdout, testSerial, testSecretType, testSecret) + if err != nil { + return fmt.Errorf("test 4 (verify no duplicate): %w", err) + } + if secretCount != 1 { + return fmt.Errorf("duplicate secret created! expected 1 secret, got %d", secretCount) + } + fmt.Fprintln(stdout, "PASS: No duplicate secret created") + + // Test 5: Send different secret (should create new entry) + fmt.Fprintln(stdout, "\n=== Test 5: Send different secret ===") + newSecret := "different-recovery-key-67890" + if err := testCheckin(handler, stdout, testSerial, testUsername, testMacName, newSecret, testSecretType); err != nil { + return fmt.Errorf("test 5 (different secret): %w", err) + } + + // Test 6: Verify new secret was created + fmt.Fprintln(stdout, "\n=== Test 6: Verify new secret created ===") + secretCount, err = countSecretsForComputer(st, testSerial, testSecretType) + if err != nil { + return fmt.Errorf("test 6 (count secrets): %w", err) + } + if secretCount != 2 { + return fmt.Errorf("expected 2 secrets after sending different secret, got %d", secretCount) + } + fmt.Fprintln(stdout, "PASS: New secret created for different value") + + fmt.Fprintln(stdout, "\n=== All integration tests passed! ===") + return nil +} + +func testCheckin(handler http.Handler, stdout io.Writer, serial, username, macName, secret, secretType string) error { + form := url.Values{} + form.Set("serial", serial) + form.Set("username", username) + form.Set("macname", macName) + form.Set("recovery_password", secret) + form.Set("secret_type", secretType) + + req := httptest.NewRequest(http.MethodPost, "/checkin/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + return fmt.Errorf("checkin failed with status %d: %s", rec.Code, rec.Body.String()) + } + + var response map[string]any + if err := json.NewDecoder(rec.Body).Decode(&response); err != nil { + return fmt.Errorf("decode response: %w", err) + } + + fmt.Fprintf(stdout, "Checkin response: serial=%v, username=%v, rotation_required=%v\n", + response["serial"], response["username"], response["rotation_required"]) + + return nil +} + +func testSecretRetrieval(st store.Store, stdout io.Writer, serial, secretType, expectedSecret string) (int, error) { + computer, err := st.GetComputerBySerial(serial) + if err != nil { + return 0, fmt.Errorf("get computer: %w", err) + } + fmt.Fprintf(stdout, "Found computer: ID=%d, Serial=%s, Name=%s\n", + computer.ID, computer.Serial, computer.ComputerName) + + secrets, err := st.ListSecretsByComputer(computer.ID) + if err != nil { + return 0, fmt.Errorf("list secrets: %w", err) + } + + count := 0 + for _, s := range secrets { + if s.SecretType == secretType { + count++ + if s.Secret == expectedSecret { + fmt.Fprintf(stdout, "PASS: Secret decrypted correctly: ID=%d, Type=%s\n", s.ID, s.SecretType) + } else if s.Secret != "" { + fmt.Fprintf(stdout, "Secret found: ID=%d, Type=%s\n", s.ID, s.SecretType) + } + } + } + + if count == 0 { + return 0, fmt.Errorf("no secrets found for type %s", secretType) + } + + return count, nil +} + +func countSecretsForComputer(st store.Store, serial, secretType string) (int, error) { + computer, err := st.GetComputerBySerial(serial) + if err != nil { + return 0, fmt.Errorf("get computer: %w", err) + } + + secrets, err := st.ListSecretsByComputer(computer.ID) + if err != nil { + return 0, fmt.Errorf("list secrets: %w", err) + } + + count := 0 + for _, s := range secrets { + if s.SecretType == secretType { + count++ + } + } + return count, nil +} + diff --git a/cmd/cryptctl/integration_test.go b/cmd/cryptctl/integration_test.go new file mode 100644 index 0000000..67e7719 --- /dev/null +++ b/cmd/cryptctl/integration_test.go @@ -0,0 +1,104 @@ +package main + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTestCheckin(t *testing.T) { + t.Run("successful checkin", func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "/checkin/", r.URL.Path) + require.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + + err := r.ParseForm() + require.NoError(t, err) + require.Equal(t, "TEST-SERIAL", r.FormValue("serial")) + require.Equal(t, "testuser", r.FormValue("username")) + require.Equal(t, "Test Mac", r.FormValue("macname")) + require.Equal(t, "secret123", r.FormValue("recovery_password")) + require.Equal(t, "recovery_key", r.FormValue("secret_type")) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "serial": "TEST-SERIAL", + "username": "testuser", + "rotation_required": false, + }) + }) + + var stdout bytes.Buffer + err := testCheckin(handler, &stdout, "TEST-SERIAL", "testuser", "Test Mac", "secret123", "recovery_key") + require.NoError(t, err) + require.Contains(t, stdout.String(), "serial=TEST-SERIAL") + require.Contains(t, stdout.String(), "username=testuser") + }) + + t.Run("checkin failure", func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("internal error")) + }) + + err := testCheckin(handler, io.Discard, "TEST-SERIAL", "testuser", "Test Mac", "secret123", "recovery_key") + require.Error(t, err) + require.Contains(t, err.Error(), "checkin failed with status 500") + }) + + t.Run("invalid json response", func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("not json")) + }) + + err := testCheckin(handler, io.Discard, "TEST-SERIAL", "testuser", "Test Mac", "secret123", "recovery_key") + require.Error(t, err) + require.Contains(t, err.Error(), "decode response") + }) +} + +func TestTestCheckinFormValues(t *testing.T) { + var capturedForm map[string][]string + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.ParseForm() + capturedForm = r.Form + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "serial": r.FormValue("serial"), + "username": r.FormValue("username"), + "rotation_required": false, + }) + }) + + err := testCheckin(handler, io.Discard, "SN-12345", "jdoe", "Johns Mac", "my-secret-key", "filevault") + require.NoError(t, err) + + require.Equal(t, []string{"SN-12345"}, capturedForm["serial"]) + require.Equal(t, []string{"jdoe"}, capturedForm["username"]) + require.Equal(t, []string{"Johns Mac"}, capturedForm["macname"]) + require.Equal(t, []string{"my-secret-key"}, capturedForm["recovery_password"]) + require.Equal(t, []string{"filevault"}, capturedForm["secret_type"]) +} + +func TestCheckinHTTPMethod(t *testing.T) { + var capturedMethod string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedMethod = r.Method + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{"serial": "X", "username": "Y", "rotation_required": false}) + }) + + req := httptest.NewRequest(http.MethodPost, "/checkin/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + require.Equal(t, http.MethodPost, capturedMethod) +} diff --git a/cmd/cryptctl/key_test.go b/cmd/cryptctl/key_test.go new file mode 100644 index 0000000..e053645 --- /dev/null +++ b/cmd/cryptctl/key_test.go @@ -0,0 +1,36 @@ +package main + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLoadKeyPriority(t *testing.T) { + os.Setenv("FIELD_ENCRYPTION_KEY", "env-key") + t.Cleanup(func() { os.Unsetenv("FIELD_ENCRYPTION_KEY") }) + + tmp := t.TempDir() + filePath := filepath.Join(tmp, "key.txt") + require.NoError(t, os.WriteFile(filePath, []byte("file-key"), 0o600)) + + value, err := loadKey("flag-key", filePath, "FIELD_ENCRYPTION_KEY") + require.NoError(t, err) + require.Equal(t, "flag-key", value) + + value, err = loadKey("", filePath, "FIELD_ENCRYPTION_KEY") + require.NoError(t, err) + require.Equal(t, "file-key", value) + + value, err = loadKey("", "", "FIELD_ENCRYPTION_KEY") + require.NoError(t, err) + require.Equal(t, "env-key", value) +} + +func TestLoadKeyMissing(t *testing.T) { + os.Unsetenv("FIELD_ENCRYPTION_KEY") + _, err := loadKey("", "", "FIELD_ENCRYPTION_KEY") + require.Error(t, err) +} diff --git a/cmd/cryptctl/main.go b/cmd/cryptctl/main.go new file mode 100644 index 0000000..908702b --- /dev/null +++ b/cmd/cryptctl/main.go @@ -0,0 +1,215 @@ +package main + +import ( + "crypto/rand" + "encoding/base64" + "errors" + "flag" + "fmt" + "io" + "os" + "strings" + + "crypt-server/internal/crypto" + "github.com/fernet/fernet-go" +) + +type fixtureEntry struct { + Model string `json:"model"` + PK interface{} `json:"pk"` + Fields map[string]interface{} `json:"fields"` +} + +// pkInt returns the integer PK value, or 0 if not an int/float (e.g., string session keys). +func (e fixtureEntry) pkInt() int { + switch v := e.PK.(type) { + case float64: + return int(v) + case int: + return v + default: + return 0 + } +} + +type migrationOutput struct { + Computers []computerOut `json:"computers"` + Secrets []secretOut `json:"secrets"` + Requests []requestOut `json:"requests"` + Users []userOut `json:"users"` +} + +type computerOut struct { + ID int `json:"id"` + Serial string `json:"serial"` + Username string `json:"username"` + ComputerName string `json:"computername"` + LastCheckin string `json:"last_checkin"` +} + +type secretOut struct { + ID int `json:"id"` + ComputerID int `json:"computer_id"` + SecretType string `json:"secret_type"` + Secret string `json:"secret"` + DateEscrowed string `json:"date_escrowed"` + RotationRequired bool `json:"rotation_required"` +} + +type requestOut struct { + ID int `json:"id"` + SecretID int `json:"secret_id"` + RequestingUser string `json:"requesting_user"` + Approved *bool `json:"approved"` + AuthUser string `json:"auth_user"` + ReasonForRequest string `json:"reason_for_request"` + ReasonForApproval string `json:"reason_for_approval"` + DateRequested string `json:"date_requested"` + DateApproved string `json:"date_approved"` + Current bool `json:"current"` +} + +type userOut struct { + ID int `json:"id"` + Username string `json:"username"` + Email string `json:"email"` + IsStaff bool `json:"is_staff"` + IsSuper bool `json:"is_superuser"` + CanApprove bool `json:"can_approve"` + Groups []string `json:"groups"` + PasswordHash string `json:"password_hash"` + MustResetPassword bool `json:"must_reset_password"` + LocalLoginEnabled bool `json:"local_login_enabled"` + AuthSource string `json:"auth_source"` +} + +func main() { + flag.Usage = func() { + fmt.Fprintln(flag.CommandLine.Output(), "Usage: cryptctl ") + fmt.Fprintln(flag.CommandLine.Output(), "") + fmt.Fprintln(flag.CommandLine.Output(), "Commands:") + fmt.Fprintln(flag.CommandLine.Output(), " gen-key Generate a base64-encoded 32-byte FIELD_ENCRYPTION_KEY") + fmt.Fprintln(flag.CommandLine.Output(), " import-fixture Convert Django JSON fixtures into an encrypted migration export") + fmt.Fprintln(flag.CommandLine.Output(), " integration-test Run integration tests against a database") + } + flag.Parse() + + if flag.NArg() < 1 { + flag.Usage() + os.Exit(2) + } + + switch flag.Arg(0) { + case "gen-key": + if err := runGenKey(os.Stdout); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + case "import-fixture": + if err := runImportFixture(os.Args[2:]); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + case "integration-test": + if err := runIntegrationTest(os.Args[2:], os.Stdout); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + default: + fmt.Fprintf(os.Stderr, "unknown command: %s\n", flag.Arg(0)) + flag.Usage() + os.Exit(2) + } +} + +func runGenKey(w io.Writer) error { + key := make([]byte, 32) + if _, err := rand.Read(key); err != nil { + return fmt.Errorf("generate key: %w", err) + } + encoded := base64.StdEncoding.EncodeToString(key) + _, err := fmt.Fprintln(w, encoded) + return err +} + +func runImportFixture(args []string) error { + fs := flag.NewFlagSet("import-fixture", flag.ExitOnError) + inputPath := fs.String("input", "", "Path to Django JSON fixture file") + outputPath := fs.String("output", "", "Path to write migration export JSON") + legacyKey := fs.String("legacy-key", "", "Base64 legacy FIELD_ENCRYPTION_KEY") + legacyKeyFile := fs.String("legacy-key-file", "", "Path to file containing legacy FIELD_ENCRYPTION_KEY") + newKey := fs.String("new-key", "", "Base64 new FIELD_ENCRYPTION_KEY") + newKeyFile := fs.String("new-key-file", "", "Path to file containing new FIELD_ENCRYPTION_KEY") + passwordMapPath := fs.String("password-map", "", "Path to CSV file mapping usernames/emails to passwords") + fs.Parse(args) + + if *inputPath == "" || *outputPath == "" { + return errors.New("input and output paths are required") + } + + legacyKeyValue, err := loadKey(*legacyKey, *legacyKeyFile, "LEGACY_FIELD_ENCRYPTION_KEY") + if err != nil { + return fmt.Errorf("load legacy key: %w", err) + } + newKeyValue, err := loadKey(*newKey, *newKeyFile, "FIELD_ENCRYPTION_KEY") + if err != nil { + return fmt.Errorf("load new key: %w", err) + } + + legacyFernetKey, err := fernet.DecodeKey(legacyKeyValue) + if err != nil { + return fmt.Errorf("decode legacy key: %w", err) + } + newCodec, err := crypto.NewAesGcmCodecFromBase64Key(newKeyValue) + if err != nil { + return fmt.Errorf("invalid new key: %w", err) + } + + fixtureBytes, err := os.ReadFile(*inputPath) + if err != nil { + return fmt.Errorf("read fixture: %w", err) + } + + entries, err := parseFixture(fixtureBytes) + if err != nil { + return fmt.Errorf("parse fixture: %w", err) + } + + passwordMap, err := loadPasswordMap(*passwordMapPath) + if err != nil { + return fmt.Errorf("load password map: %w", err) + } + + output, err := convertFixture(entries, legacyFernetKey, newCodec, passwordMap) + if err != nil { + return fmt.Errorf("convert fixture: %w", err) + } + + payload, err := marshalOutput(output) + if err != nil { + return fmt.Errorf("encode output: %w", err) + } + + if err := os.WriteFile(*outputPath, payload, 0o600); err != nil { + return fmt.Errorf("write output: %w", err) + } + + return nil +} + +func loadKey(value, path, env string) (string, error) { + if value != "" { + return strings.TrimSpace(value), nil + } + if path != "" { + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + return strings.TrimSpace(string(data)), nil + } + if envValue := os.Getenv(env); envValue != "" { + return strings.TrimSpace(envValue), nil + } + return "", fmt.Errorf("missing key: provide --key, --key-file, or %s", env) +} diff --git a/cmd/cryptctl/password_map.go b/cmd/cryptctl/password_map.go new file mode 100644 index 0000000..ec8ed1b --- /dev/null +++ b/cmd/cryptctl/password_map.go @@ -0,0 +1,103 @@ +package main + +import ( + "crypto/rand" + "encoding/base64" + "encoding/csv" + "errors" + "fmt" + "os" + "strconv" + "strings" + + "golang.org/x/crypto/argon2" +) + +type passwordMapEntry struct { + PasswordHash string + MustResetPassword bool +} + +func loadPasswordMap(path string) (map[string]passwordMapEntry, error) { + if path == "" { + return map[string]passwordMapEntry{}, nil + } + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + reader := csv.NewReader(file) + records, err := reader.ReadAll() + if err != nil { + return nil, err + } + + out := make(map[string]passwordMapEntry) + for i, record := range records { + if len(record) == 0 { + continue + } + if i == 0 && isPasswordMapHeader(record) { + continue + } + if len(record) < 2 || len(record) > 3 { + return nil, fmt.Errorf("invalid password map record %d", i+1) + } + key := strings.ToLower(strings.TrimSpace(record[0])) + if key == "" { + return nil, fmt.Errorf("invalid password map record %d", i+1) + } + password := strings.TrimSpace(record[1]) + if password == "" { + return nil, fmt.Errorf("password missing for %s", key) + } + mustReset := false + if len(record) == 3 && strings.TrimSpace(record[2]) != "" { + value, err := strconv.ParseBool(strings.TrimSpace(record[2])) + if err != nil { + return nil, fmt.Errorf("invalid must_reset_password value for %s", key) + } + mustReset = value + } + hash, err := hashPasswordForExport(password) + if err != nil { + return nil, err + } + if _, exists := out[key]; exists { + return nil, fmt.Errorf("duplicate password map entry for %s", key) + } + out[key] = passwordMapEntry{ + PasswordHash: hash, + MustResetPassword: mustReset, + } + } + return out, nil +} + +func isPasswordMapHeader(record []string) bool { + if len(record) < 2 { + return false + } + first := strings.ToLower(strings.TrimSpace(record[0])) + return first == "username_or_email" +} + +func hashPasswordForExport(plaintext string) (string, error) { + if plaintext == "" { + return "", errors.New("password is required") + } + salt := make([]byte, 16) + if _, err := rand.Read(salt); err != nil { + return "", fmt.Errorf("generate salt: %w", err) + } + hash := argon2.IDKey([]byte(plaintext), salt, 1, 64*1024, 4, 32) + return fmt.Sprintf("$argon2id$%d$%d$%d$%s$%s", + 1, + 64*1024, + 4, + base64.RawStdEncoding.EncodeToString(salt), + base64.RawStdEncoding.EncodeToString(hash), + ), nil +} diff --git a/cmd/cryptctl/password_map_test.go b/cmd/cryptctl/password_map_test.go new file mode 100644 index 0000000..2587980 --- /dev/null +++ b/cmd/cryptctl/password_map_test.go @@ -0,0 +1,33 @@ +package main + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLoadPasswordMap(t *testing.T) { + tmp := t.TempDir() + path := filepath.Join(tmp, "passwords.csv") + err := os.WriteFile(path, []byte("username_or_email,password,must_reset_password\nadmin,Str0ng!Passw0rd,false\n"), 0o600) + require.NoError(t, err) + + entries, err := loadPasswordMap(path) + require.NoError(t, err) + entry, ok := entries["admin"] + require.True(t, ok) + require.False(t, entry.MustResetPassword) + require.NotEmpty(t, entry.PasswordHash) +} + +func TestLoadPasswordMapErrors(t *testing.T) { + tmp := t.TempDir() + path := filepath.Join(tmp, "passwords.csv") + err := os.WriteFile(path, []byte("user, ,\n"), 0o600) + require.NoError(t, err) + + _, err = loadPasswordMap(path) + require.Error(t, err) +} diff --git a/crypt.wsgi b/crypt.wsgi deleted file mode 100644 index 1badf16..0000000 --- a/crypt.wsgi +++ /dev/null @@ -1,16 +0,0 @@ -import os, sys -import site - -CRYPT_ENV_DIR = '/usr/local/crypt_env' - -# Use site to load the site-packages directory of our virtualenv -site.addsitedir(os.path.join(CRYPT_ENV_DIR, 'lib/python2.7/site-packages')) -# -# # Make sure we have the virtualenv and the Django app itself added to our path -sys.path.append(CRYPT_ENV_DIR) -sys.path.append(os.path.join(CRYPT_ENV_DIR, 'crypt')) - -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'fvserver.settings') - -from django.core.wsgi import get_wsgi_application -application = get_wsgi_application() diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index 5dd9242..0000000 --- a/docker-compose.yml +++ /dev/null @@ -1,27 +0,0 @@ -services: - # Uncomment here if you want to use Caddy - # caddy: - # image: "wemakeservices/caddy-docker:latest" - # volumes: - # - ./docker/Caddyfile:/etc/Caddyfile # to mount custom Caddyfile - # ports: - # - "80:80" - # # - "443:443" # uncomment this for https. Make sure you edit the Caddyfile above to reflect your hostname - # depends_on: - # - crypt - # restart: always - - crypt: - image: macadmins/crypt-server - # OR "crypt-server" for local build using documentation in /docs/Docker.md - # build: . # uncomment this to build your own image through Docker Compose - environment: - - FIELD_ENCRYPTION_KEY=jKAv1Sde8m6jCYFnmps0iXkUfAilweNVjbvoebBrDwg= # please change this - - ADMIN_PASS=password - - DEBUG=false - ports: - - "8000:8000" - volumes: - - ${PWD}/crypt.db:/home/app/crypt/crypt.db # This will do a local database. For production you should use postgresql - - ${PWD}/fvserver/settings.py:/home/app/crypt/fvserver/settings.py # Load in your own settings file - restart: always diff --git a/docker/Caddyfile b/docker/Caddyfile deleted file mode 100644 index 9fa74df..0000000 --- a/docker/Caddyfile +++ /dev/null @@ -1,3 +0,0 @@ -*:80 { - proxy / http://crypt:8000 -} \ No newline at end of file diff --git a/docker/README.md b/docker/README.md deleted file mode 100644 index 5d88fe4..0000000 --- a/docker/README.md +++ /dev/null @@ -1,30 +0,0 @@ -__[Crypt][1]__ is a system for centrally storing FileVault 2 recovery keys. It is made up of a client app, and a Django web app for storing the keys. - -This Docker image contains the fully configured Crypt Django web app. A default admin user has been preconfigured, use admin/password to login. -If you intend on using the server for anything semi-serious it is a good idea to change the password or add a new admin user and delete the default one. - -The secrets are encrypted, with the encryption keys stored at ``/home/docker/crypt/keyset``. You should mount this on your host to preserve the keys: - -``` --v /somewhere/on/the/host:/home/docker/crypt/keyset -``` - -__Changes in this version__ -================= - -- 10.7 is no longer supported. -- Improved logging on errors. -- Improved user feedback during long operations (such as enabling FileVault). - -__Client__ -==== -The client is written in Pyobjc, and makes use of the built in fdesetup on OS X 10.8 and higher. An example login hook is provided to see how this could be implemented in your organisation. - -__Features__ -======= -- If escrow fails for some reason, the recovery key is stored on disk and a Launch Daemon will attempt to escrow the key periodically. -- If the app cannot contact the server, it can optionally quit. -- If FileVault is already enabled, the app will quit. - - - [1]: https://github.com/grahamgilbert/Crypt diff --git a/docker/django/management/__init__.py b/docker/django/management/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/docker/django/management/commands/__init__.py b/docker/django/management/commands/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/docker/django/management/commands/update_admin_user.py b/docker/django/management/commands/update_admin_user.py deleted file mode 100644 index 2532666..0000000 --- a/docker/django/management/commands/update_admin_user.py +++ /dev/null @@ -1,59 +0,0 @@ -""" -Creates an admin user if there aren't any existing superusers -""" - -from django.core.management.base import BaseCommand, CommandError -from django.contrib.auth.models import User -from optparse import make_option - - -class Command(BaseCommand): - help = "Creates/Updates an Admin user" - - # option_list = BaseCommand.option_list + ( - # make_option('--username', - # action='store', - # dest='username', - # default=None, - # help='Admin username'), - # ) + ( - # make_option('--password', - # action='store', - # dest='password', - # default=None, - # help='Admin password'), - # ) - def add_arguments(self, parser): - parser.add_argument( - "--username", - action="store", - dest="username", - default=None, - help="Admin username", - ) - - parser.add_argument( - "--password", - action="store", - dest="password", - default=None, - help="Admin password", - ) - - def handle(self, *args, **options): - username = options.get("username") - password = options.get("password") - if not username or not password: - raise StandardError("You must specify a username and password") - # Get the current superusers - su_count = User.objects.filter(is_superuser=True).count() - if su_count == 0: - # there aren't any superusers, create one - user, created = User.objects.get_or_create(username=username) - user.set_password(password) - user.is_staff = True - user.is_superuser = True - user.save() - print("{0} updated".format(username)) - else: - print("There are already {0} superusers".format(su_count)) diff --git a/docker/gunicorn_config.py b/docker/gunicorn_config.py deleted file mode 100644 index 6e6685d..0000000 --- a/docker/gunicorn_config.py +++ /dev/null @@ -1,6 +0,0 @@ -import multiprocessing - -bind = "0.0.0.0:8000" -workers = multiprocessing.cpu_count() * 2 + 1 -errorlog = "-" -accesslog = "-" diff --git a/docker/nginx/crypt.conf b/docker/nginx/crypt.conf deleted file mode 100644 index 66a2a9b..0000000 --- a/docker/nginx/crypt.conf +++ /dev/null @@ -1,24 +0,0 @@ -# Crypt.conf: -server { - listen 8000; - server_name crypt.local; - root /home/docker/crypt/static/; - # Redirect requests for static files - location /static/ { - alias /home/docker/crypt/static/; - } - - error_log /var/log/nginx/crypt-error.log warn; - - location / { - proxy_pass http://127.0.0.1:8001; - proxy_set_header X-Forwarded-Host $server_name; - proxy_set_header Host $http_host; - proxy_redirect off; - proxy_set_header X-Real-IP $remote_addr; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - add_header P3P 'CP="ALL DSP COR PSAa PSDa OUR NOR ONL UNI COM NAV"'; - port_in_redirect off; - add_header X-Frame-Options sameorigin; - } -} diff --git a/docker/nginx/nginx-env.conf b/docker/nginx/nginx-env.conf deleted file mode 100644 index 85149d3..0000000 --- a/docker/nginx/nginx-env.conf +++ /dev/null @@ -1,15 +0,0 @@ -# Environment Variables for settings.py -env DB_NAME; -env DB_USER; -env DB_PASS; -env DB_PORT_5432_TCP_ADDR; -env DB_PORT_5432_TCP_PORT; -env DEBIAN_FRONTEND; -env APP_DIR; -env DOCKER_CRYPT_TZ; -env DOCKER_CRYPT_ADMINS; -env DOCKER_CRYPT_ALLOWED; -env DOCKER_CRYPT_LANG; -env DOCKER_CRYPT_PLUGIN_ORDER; -env DOCKER_CRYPT_DISPLAY_NAME; -env APPNAME; diff --git a/docker/nginx/nginx.conf b/docker/nginx/nginx.conf deleted file mode 100644 index e29b35e..0000000 --- a/docker/nginx/nginx.conf +++ /dev/null @@ -1,99 +0,0 @@ -user www-data; -worker_processes 4; - -events { - worker_connections 19000; -} - -worker_rlimit_nofile 20000; -pid /run/nginx.pid; -daemon off; - -include /etc/nginx/main.d/*.conf; - -http { - - ## - # Basic Settings - ## - - sendfile on; - tcp_nopush on; - tcp_nodelay on; - keepalive_timeout 65; - types_hash_max_size 2048; - # server_tokens off; - - # server_names_hash_bucket_size 64; - # server_name_in_redirect off; - - include /etc/nginx/mime.types; - default_type application/octet-stream; - - ## - # Logging Settings - ## - - access_log /var/log/nginx/access.log; - error_log /var/log/nginx/error.log; - - ## - # Gzip Settings - ## - - gzip on; - gzip_disable "msie6"; - - # gzip_vary on; - # gzip_proxied any; - # gzip_comp_level 6; - # gzip_buffers 16 8k; - # gzip_http_version 1.1; - # gzip_types text/plain text/css application/json application/x-javascript text/xml application/xml application/xml+rss text/javascript; - - ## - # nginx-naxsi config - ## - # Uncomment it if you installed nginx-naxsi - ## - - # include /etc/nginx/naxsi_core.rules; - - ## - # Phusion Passenger config - ## - # Uncomment it if you installed passenger or passenger-enterprise - ## - - # passenger_root /usr/lib/ruby/vendor_ruby/phusion_passenger/locations.ini; - # passenger_ruby /usr/bin/ruby; - - ## - # Virtual Host Configs - ## - - include /etc/nginx/conf.d/*.conf; - include /etc/nginx/sites-enabled/*; -} - - -# mail { -# # See sample authentication script at: -# # http://wiki.nginx.org/ImapAuthenticateWithApachePhpScript -# -# # auth_http localhost/auth.php; -# # pop3_capabilities "TOP" "USER"; -# # imap_capabilities "IMAP4rev1" "UIDPLUS"; -# -# server { -# listen localhost:110; -# protocol pop3; -# proxy on; -# } -# -# server { -# listen localhost:143; -# protocol imap; -# proxy on; -# } -# } \ No newline at end of file diff --git a/docker/postgres.sh b/docker/postgres.sh deleted file mode 100755 index f6a54c4..0000000 --- a/docker/postgres.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash - -docker rm -f postgres-crypt - -docker run -d --name="postgres-crypt" \ - -e DB_NAME=crypt \ - -e DB_USER=admin \ - -e DB_PASS=password \ - -v /Users/Shared/test-pg-db:/var/lib/postgresql/data \ - -p 5432:5432 \ - grahamgilbert/postgres:9.4.5 - -sleep 30 diff --git a/docker/run.sh b/docker/run.sh deleted file mode 100755 index ddd7128..0000000 --- a/docker/run.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/sh - -set -e - -cd $APP_DIR -ADMIN_PASS=${ADMIN_PASS:-} -# python3 generate_keyczart.py -python3 manage.py migrate --noinput - -if [ ! -z "$ADMIN_PASS" ] ; then - python3 manage.py update_admin_user --username=admin --password=$ADMIN_PASS -else - python3 manage.py update_admin_user --username=admin --password=password -fi - - -export PYTHONPATH=$PYTHONPATH:$APP_DIR -export DJANGO_SETTINGS_MODULE='fvserver.settings' - -if [ "${DOCKER_CRYPT_DEBUG}" = "true" ] || [ "${DOCKER_CRYPT_DEBUG}" = "True" ] || [ "${DOCKER_CRYPT_DEBUG}" = "TRUE" ] ; then - echo "RUNNING IN DEBUG MODE" - python3 manage.py runserver 0.0.0.0:8000 -else - gunicorn -c $APP_DIR/gunicorn_config.py fvserver.wsgi:application -fi diff --git a/docker/run_docker.sh b/docker/run_docker.sh deleted file mode 100755 index 2eac9c8..0000000 --- a/docker/run_docker.sh +++ /dev/null @@ -1,14 +0,0 @@ -CWD=`pwd` -docker rm -f crypt - -docker build -t macadmins/crypt . -docker run -d \ - -e ADMIN_PASS=pass \ - -e DEBUG=false \ - -e PROMETHEUS=true \ - --name=crypt \ - --restart="always" \ - -v "$CWD/crypt.db":/home/docker/crypt/crypt.db \ - -e FIELD_ENCRYPTION_KEY=jKAv1Sde8m6jCYFnmps0iXkUfAilweNVjbvoebBrDwg= \ - -p 8000-8050:8000-8050 \ - macadmins/crypt diff --git a/docker/run_docker_postgres.sh b/docker/run_docker_postgres.sh deleted file mode 100755 index 046129b..0000000 --- a/docker/run_docker_postgres.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -CWD=`pwd` -docker rm -f crypt -docker build -t macadmins/crypt --no-cache . -docker run -d \ - -e ADMIN_PASS=pass \ - -e DEBUG=false \ - -e DB_NAME=crypt \ - -e DB_USER=admin \ - -e DB_PASS=password \ - --name=crypt \ - --link postgres-crypt:db \ - --restart="always" \ - -v "$CWD/keyset":/home/docker/crypt/keyset \ - -e FIELD_ENCRYPTION_KEY=jKAv1Sde8m6jCYFnmps0iXkUfAilweNVjbvoebBrDwg= \ - -p 8000-8050:8000-8050 \ - macadmins/crypt diff --git a/docker/settings.py b/docker/settings.py deleted file mode 100644 index b14dc8f..0000000 --- a/docker/settings.py +++ /dev/null @@ -1,89 +0,0 @@ -from fvserver.system_settings import * -from fvserver.settings_import import * -from django.utils.log import DEFAULT_LOGGING -import os - - -# Django settings for fvserver project. - -DATABASES = { - "default": { - "ENGINE": "django.db.backends.sqlite3", # Add 'postgresql_psycopg2', 'mysql', 'sqlite3' or 'oracle'. - "NAME": os.path.join( - PROJECT_DIR, "crypt.db" - ), # Or path to database file if using sqlite3. - "USER": "", # Not used with sqlite3. - "PASSWORD": "", # Not used with sqlite3. - "HOST": "", # Set to empty string for localhost. Not used with sqlite3. - "PORT": "", # Set to empty string for default. Not used with sqlite3. - } -} - -host = None -port = None - -if "DB_HOST" in os.environ: - host = os.environ.get("DB_HOST") - port = os.environ.get("DB_PORT") - -elif "DB_PORT_5432_TCP_ADDR" in os.environ: - host = os.environ.get("DB_PORT_5432_TCP_ADDR") - port = os.environ.get("DB_PORT_5432_TCP_PORT", "5432") - -if host and port: - DATABASES = { - "default": { - "ENGINE": "django.db.backends.postgresql_psycopg2", - "NAME": os.environ["DB_NAME"], - "USER": os.environ["DB_USER"], - "PASSWORD": os.environ["DB_PASS"], - "HOST": host, - "PORT": port, - } - } - -if "AWS_IAM" in os.environ: - import requests - - cert_bundle_url = ( - "https://truststore.pki.rds.amazonaws.com/global/global-bundle.pem" - ) - cert_target_path = "/etc/ssl/certs/global-bundle.pem" - - response = requests.get(cert_bundle_url) - if response.status_code == 200: - os.makedirs(os.path.dirname(cert_target_path), exist_ok=True) - - with open(cert_target_path, "wb") as file: - file.write(response.content) - print( - f"AWS RDS cert bundle successfully downloaded and saved to {cert_target_path}" - ) - else: - print( - f"Failed to download AWS RDS cert bundle, status code: {response.status_code}" - ) - DATABASES = { - "default": { - "ENGINE": "django_iam_dbauth.aws.postgresql", - "NAME": os.environ["DB_NAME"], - "USER": os.environ["DB_USER"], - "HOST": os.environ["DB_HOST"], - "PORT": os.environ["DB_PORT"], - "OPTIONS": { - "region_name": os.environ["AWS_RDS_REGION"], - "sslmode": "verify-full", - "sslrootcert": "/etc/ssl/certs/global-bundle.pem", - "use_iam_auth": True, - }, - } - } - -# Don't filter anything going to console -DEFAULT_LOGGING["handlers"]["console"]["filters"] = [] - -DEFAULT_LOGGING["loggers"][""] = { - "handlers": ["console"], - "level": "INFO", - "propagate": True, -} diff --git a/docker/settings_import.py b/docker/settings_import.py deleted file mode 100644 index 4ba097d..0000000 --- a/docker/settings_import.py +++ /dev/null @@ -1,121 +0,0 @@ -#!/usr/bin/env python -from os import getenv -import locale - -# Read the DEBUG setting from env var -try: - if getenv("DEBUG").lower() == "true": - DEBUG = True - else: - DEBUG = False -except: - DEBUG = False - -try: - if getenv("APPROVE_OWN").lower() == "false": - APPROVE_OWN = False - else: - APPROVE_OWN = True -except: - APPROVE_OWN = True - -try: - if getenv("ROTATE_VIEWED_SECRETS").lower() == "false": - ROTATE_VIEWED_SECRETS = False - else: - ROTATE_VIEWED_SECRETS = True -except: - ROTATE_VIEWED_SECRETS = True - -try: - if getenv("ALL_APPROVE").lower() == "true": - ALL_APPROVE = True - else: - ALL_APPROVE = False -except: - ALL_APPROVE = False - -# Read list of admins from $ADMINS env var -admin_list = [] -if getenv("ADMINS"): - admins_var = getenv("ADMINS") - if "," in admins_var and ":" in admins_var: - for admin in admins_var.split(":"): - admin_list.append(tuple(admin.split(","))) - ADMINS = admin_list - elif "," in admins_var: - admin_list.append(tuple(admins_var.split(","))) - ADMINS = admin_list -else: - ADMINS = [("Admin User", "admin@test.com")] - -# Read the preferred time zone from $TZ, use system locale or -# set to 'America/New_York' if neither are set -if getenv("TZ"): - if "/" in getenv("TZ"): - TIME_ZONE = getenv("TZ") - else: - TIME_ZONE = "America/New_York" -elif getenv("TZ"): - TIME_ZONE = getenv("TZ") -else: - TIME_ZONE = "America/New_York" - -# Read the preferred language code from $LANG & default to en-us if not set -# note django does not support locale-format for LANG -if getenv("LANG"): - LANGUAGE_CODE = getenv("LANG") -else: - LANGUAGE_CODE = "en-us" - -# Set the display name from the $DISPLAY_NAME env var, or -# use the default -if getenv("DISPLAY_NAME"): - DISPLAY_NAME = getenv("DISPLAY_NAME") -else: - DISPLAY_NAME = "Crypt" - -if getenv("EMAIL_HOST"): - EMAIL_HOST = getenv("EMAIL_HOST") - -if getenv("EMAIL_PORT"): - EMAIL_PORT = getenv("EMAIL_PORT") - -if getenv("EMAIL_USER"): - EMAIL_USER = getenv("EMAIL_USER") - -if getenv("EMAIL_PASSWORD"): - EMAIL_PASSWORD = getenv("EMAIL_PASSWORD") - -if getenv("CSRF_TRUSTED_ORIGINS"): - CSRF_TRUSTED_ORIGINS = getenv("CSRF_TRUSTED_ORIGINS").split(",") -else: - CSRF_TRUSTED_ORIGINS = [] - -if getenv("HOST_NAME"): - HOST_NAME = getenv("HOST_NAME") -else: - HOST_NAME = "https://cryptexample.com" - -if getenv("EMAIL_SENDER"): - EMAIL_SENDER = getenv("EMAIL_SENDER") -else: - EMAIL_SENDER = "crypt@cryptexample.com" - -# Read the list of allowed hosts from the $DOCKER_CRYPT_ALLOWED env var, or -# allow all hosts if none was set. -if getenv("ALLOWED_HOSTS"): - ALLOWED_HOSTS = getenv("ALLOWED_HOSTS").split(",") -else: - ALLOWED_HOSTS = ["*"] - -if getenv("SEND_EMAIL") and getenv("SEND_EMAIL").lower() == "true": - SEND_EMAIL = True -else: - SEND_EMAIL = False - -if getenv("EMAIL_USE_TLS") and getenv("EMAIL_USE_TLS").lower() == "true": - EMAIL_USE_TLS = True - -if getenv("EMAIL_USE_SSL") and getenv("EMAIL_USE_SSL").lower() == "true": - EMAIL_USE_SSL = True diff --git a/docker/setup_db.sh b/docker/setup_db.sh deleted file mode 100644 index d4d72de..0000000 --- a/docker/setup_db.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/sh -DB_NAME=crypt -DB_USER=admin -DB_PASS=password - -echo "CREATE ROLE $DB_USER WITH LOGIN ENCRYPTED PASSWORD '${DB_PASS}' CREATEDB;" | docker run \ - --rm \ - --interactive \ - --link postgres-crypt:postgres \ - postgres:9.3.4 \ - bash -c 'exec psql -h "$POSTGRES_PORT_5432_TCP_ADDR" -p "$POSTGRES_PORT_5432_TCP_PORT" -U postgres' - -echo "CREATE DATABASE $DB_NAME WITH OWNER $DB_USER TEMPLATE template0 ENCODING 'UTF8';" | docker run \ - --rm \ - --interactive \ - --link postgres-crypt:postgres \ - postgres:9.3.4 \ - bash -c 'exec psql -h "$POSTGRES_PORT_5432_TCP_ADDR" -p "$POSTGRES_PORT_5432_TCP_PORT" -U postgres' - -echo "GRANT ALL PRIVILEGES ON DATABASE $DB_NAME TO $DB_USER;" | docker run \ - --rm \ - --interactive \ - --link postgres-crypt:postgres \ - postgres:9.3.4 \ - bash -c 'exec psql -h "$POSTGRES_PORT_5432_TCP_ADDR" -p "$POSTGRES_PORT_5432_TCP_PORT" -U postgres' - - diff --git a/docker/wsgi.py b/docker/wsgi.py deleted file mode 100644 index dba5adf..0000000 --- a/docker/wsgi.py +++ /dev/null @@ -1,15 +0,0 @@ -import os, sys -import site - -CRYPT_ENV_DIR = "/home/docker/crypt" - -# Use site to load the site-packages directory of our virtualenv -site.addsitedir(os.path.join(CRYPT_ENV_DIR, "lib/python2.7/site-packages")) - -# Make sure we have the virtualenv and the Django app itself added to our path -sys.path.append(CRYPT_ENV_DIR) -sys.path.append(os.path.join(CRYPT_ENV_DIR, "fvserver")) -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "fvserver.settings") -from django.core.wsgi import get_wsgi_application - -application = get_wsgi_application() diff --git a/docs/Development.md b/docs/Development.md index 84b8ebe..a1a40ec 100644 --- a/docs/Development.md +++ b/docs/Development.md @@ -1,10 +1,118 @@ -## Git setup +# Development -Add the following to `.git/hooks/pre-commit` and make executable +## Prerequisites +- Go 1.22 or later +- SQLite or PostgreSQL for database + +## Building + +```bash +# Build the server +make build + +# Build the migration tool +make cryptctl +``` + +## Running locally + +### Quick start with SQLite + +```bash +# Generate keys (save these for reuse) +export FIELD_ENCRYPTION_KEY=$(./cryptctl gen-key) +export SESSION_KEY=$(./cryptctl gen-key) + +# Run with SQLite +make run-sqlite +``` + +Or manually: + +```bash +export FIELD_ENCRYPTION_KEY=$(./cryptctl gen-key) +export SESSION_KEY=$(./cryptctl gen-key) +export SQLITE_PATH=./crypt.db + +./crypt-server +``` + +### With PostgreSQL + +```bash +export FIELD_ENCRYPTION_KEY=$(./cryptctl gen-key) +export SESSION_KEY=$(./cryptctl gen-key) +export DATABASE_URL="postgres://user:pass@localhost:5432/crypt" + +./crypt-server +``` + +### Create admin user + +```bash +./crypt-server -create-admin -username=admin -password='password' ``` -#!/bin/bash -ROOT=`git rev-parse --show-toplevel` -$ROOT/set_build_no.py -git add fvserver/version.plist + +## Testing + +```bash +# Run all tests +make test + +# Run tests with verbose output +go test -v ./... + +# Run tests for a specific package +go test -v ./internal/store/... +``` + +## Project structure + +``` +. +├── cmd/ +│ ├── crypt-server/ # Main server binary +│ └── cryptctl/ # Migration/utility tool +├── internal/ +│ ├── app/ # HTTP handlers, routing, sessions +│ ├── crypto/ # Encryption (AES-GCM) +│ ├── fixture/ # Migration fixture types +│ ├── migrate/ # Database migrations +│ └── store/ # Database layer (PostgreSQL, SQLite) +├── web/ +│ └── templates/ # HTML templates +└── docs/ # Documentation +``` + +## Database migrations + +Migrations are embedded in the binary and run automatically on startup. + +```bash +# Validate migrations +./crypt-server -validate-migrations + +# Print migrations +./crypt-server -print-migrations + +# Target specific driver +./crypt-server -validate-migrations -migrations-driver=postgres +``` + +Migration files are in `internal/migrate/migrations/{postgres,sqlite}/`. + +## Docker + +```bash +# Build image +docker build -t crypt-server . + +# Run +docker run -p 8080:8080 \ + -e FIELD_ENCRYPTION_KEY="..." \ + -e SESSION_KEY="..." \ + -e SQLITE_PATH=/data/crypt.db \ + -v ./data:/data \ + crypt-server ``` diff --git a/docs/Docker.md b/docs/Docker.md index cccccbe..fb4b79a 100644 --- a/docs/Docker.md +++ b/docs/Docker.md @@ -1,132 +1,197 @@ # Using Docker -## Server Initialization -This was last tested on Ubuntu 24.04 x86. This process may need to be modified for older installations. +## Quick Start -``` bash -git clone https://github.com/grahamgilbert/Crypt-Server.git -``` +### 1. Generate encryption keys -Install Docker and Docker Compose plugin following instructions here: -``` bash -https://docs.docker.com/engine/install/ubuntu/#install-using-the-repository -``` +```bash +# Generate field encryption key +docker run --rm ghcr.io/grahamgilbert/crypt-server ./cryptctl gen-key > field-encryption-key.txt -Restart the Docker services -``` bash -sudo systemctl restart docker +# Generate session key +docker run --rm ghcr.io/grahamgilbert/crypt-server ./cryptctl gen-key > session-key.txt ``` -Ensure docker permissions are set. Log out then back in after running this command: -``` bash -sudo usermod -aG docker $USER -``` +### 2. Create database file -## Prepare for first use -When starting from scratch, create a new empty file on the docker host to hold the sqlite3 secrets database -``` bash -touch /somewhere/else/on/the/host -``` +For SQLite (simplest setup): -## Basic usage -``` bash -docker run -d --name="Crypt" \ ---restart="always" \ --v /somewhere/else/on/the/host:/home/docker/crypt/crypt.db \ --e FIELD_ENCRYPTION_KEY='yourencryptionkey' \ --p 8000:8000 \ -macadmins/crypt-server +```bash +touch /path/to/crypt.db ``` -## Verify Operation -``` bash -docker logs Crypt +### 3. Run the container + +```bash +docker run -d --name="crypt" \ + --restart="always" \ + -v /path/to/crypt.db:/data/crypt.db \ + -e FIELD_ENCRYPTION_KEY="$(cat field-encryption-key.txt)" \ + -e SESSION_KEY="$(cat session-key.txt)" \ + -e SQLITE_PATH=/data/crypt.db \ + -p 8080:8080 \ + ghcr.io/grahamgilbert/crypt-server ``` -## Upgrading from Crypt Server 2 +### 4. Create admin user -The encryption method has changed in Crypt Server. You should pass in both your old encryption keys (e.g. `-v /somewhere/on/the/host:/home/docker/crypt/keyset`) and the new one (see below) for the first run to migrate your keys. After the migration you no longer need your old encryption keys. Crypt 3 is a major update, you should ensure any custom settings you pass are still valid. +```bash +docker exec crypt ./crypt-server \ + -create-admin \ + -username=admin \ + -password='your-secure-password' +``` +### 5. Verify operation +```bash +docker logs crypt +``` -The secrets are encrypted, with the encryption key passed in as an environment variable. You should back this up as the keys are not recoverable without them. +Access the web interface at `http://localhost:8080`. -### Generating an encryption key +## Using PostgreSQL -Run the following command to generate an encryption key (you should specify the string only): +For production deployments, PostgreSQL is recommended: -``` -docker run --rm -ti macadmins/crypt-server \ -python3 -c "from cryptography.fernet import Fernet; print(Fernet.generate_key())" +```bash +docker run -d --name="crypt" \ + --restart="always" \ + -e FIELD_ENCRYPTION_KEY="$(cat field-encryption-key.txt)" \ + -e SESSION_KEY="$(cat session-key.txt)" \ + -e DATABASE_URL="postgres://user:pass@db.example.com:5432/crypt" \ + -p 8080:8080 \ + ghcr.io/grahamgilbert/crypt-server ``` -## Backing up the database with a data dump -``` bash -docker exec -it Crypt bash -cd /home/docker/crypt/ -python manage.py dumpdata > db.json -exit -docker cp Crypt:/home/docker/crypt/db.json . +## Environment Variables + +### Required + +| Variable | Description | +|----------|-------------| +| `FIELD_ENCRYPTION_KEY` | Base64-encoded 32-byte key for encrypting secrets | +| `SESSION_KEY` | Random string (at least 32 bytes) for signing session cookies | + +### Database (one required) + +| Variable | Description | +|----------|-------------| +| `DATABASE_URL` | PostgreSQL connection string | +| `SQLITE_PATH` | Path to SQLite database file | + +### Optional + +| Variable | Default | Description | +|----------|---------|-------------| +| `SESSION_COOKIE_SECURE` | `false` | Set to `true` when using HTTPS | +| `SAML_CONFIG_FILE` | - | Path to SAML configuration YAML file | +| `APPROVE_OWN` | `false` | Allow users to approve their own requests | +| `ALL_APPROVE` | `false` | Grant all users approval permissions | +| `ROTATE_VIEWED_SECRETS` | `false` | Instruct clients to rotate secrets after viewing | + +## Docker Compose + +Create a `docker-compose.yml`: + +```yaml +services: + crypt: + image: ghcr.io/grahamgilbert/crypt-server + restart: always + ports: + - "8080:8080" + environment: + - FIELD_ENCRYPTION_KEY=${FIELD_ENCRYPTION_KEY} + - SESSION_KEY=${SESSION_KEY} + - SQLITE_PATH=/data/crypt.db + - SESSION_COOKIE_SECURE=true + volumes: + - ./data:/data ``` -Optionally -``` bash -docker exec -it Crypt bash -rm /home/docker/crypt/db.json -exit -``` - -## Using Postgres as an external database -Crypt, by default, uses a sqlite3 database for the django db backend. Crypt also supports using Postgres as the django db backend. If you would like to use an external Postgres server, you need to set the following environment variables: +Create a `.env` file: +```bash +FIELD_ENCRYPTION_KEY=your-base64-key-here +SESSION_KEY=your-session-key-here ``` -docker run -d --name="Crypt" \ ---restart="always" \ --p 8000:8000 \ --e DB_HOST='db.example.com' \ --e DB_PORT='5432' \ --e DB_NAME='postgres_dbname' \ --e DB_USER='postgres_user' \ --e DB_PASS='postgres_user_pass' \ --e FIELD_ENCRYPTION_KEY='yourencryptionkey' \ --e CSRF_TRUSTED_ORIGINS='https://FirstServer.com,https://SecondServer.com' \ -macadmins/crypt-server + +Run: + +```bash +mkdir -p data && touch data/crypt.db +docker compose up -d ``` -## Emails +## SSL/TLS + +Using Crypt without SSL **will** result in your secrets being compromised. Options: + +1. **Reverse proxy** (recommended): Use nginx, Caddy, or Traefik in front of Crypt +2. **Load balancer**: Terminate SSL at your load balancer + +Example with Caddy: + +```yaml +services: + crypt: + image: ghcr.io/grahamgilbert/crypt-server + environment: + - FIELD_ENCRYPTION_KEY=${FIELD_ENCRYPTION_KEY} + - SESSION_KEY=${SESSION_KEY} + - SQLITE_PATH=/data/crypt.db + - SESSION_COOKIE_SECURE=true + volumes: + - ./data:/data + + caddy: + image: caddy:2 + ports: + - "80:80" + - "443:443" + volumes: + - ./Caddyfile:/etc/caddy/Caddyfile + - caddy_data:/data + depends_on: + - crypt + +volumes: + caddy_data: +``` -If you would like Crypt to send emails when keys are requested and approved, you should set the following environment variables: +Example `Caddyfile`: ``` -docker run -d --name="Crypt" \ ---restart="always" \ --v /somewhere/on/the/host:/home/docker/crypt/keyset \ --v /somewhere/else/on/the/host:/home/docker/crypt/crypt.db \ --p 8000:8000 \ --e EMAIL_HOST='mail.yourdomain.com' \ --e EMAIL_PORT='25' \ --e EMAIL_USER='youruser' \ --e EMAIL_PASSWORD='yourpassword' \ --e HOST_NAME='https://crypt.myorg.com' \ --e FIELD_ENCRYPTION_KEY='yourencryptionkey' \ --e CSRF_TRUSTED_ORIGINS='https://FirstServer.com,https://SecondServer.com' \ -macadmins/crypt-server +crypt.example.com { + reverse_proxy crypt:8080 +} ``` -If your SMTP server doesn't need a setting (username and password for example), you should omit it. The `HOST_NAME` setting should be the hostname of your server - this will be used to generate links in emails. +## Backing Up -## SSL +### SQLite -It is recommended to use either an Nginx proxy in front of the Crypt app for SSL termination (outside of the scope of this document, see [here](https://www.digitalocean.com/community/tutorials/how-to-secure-nginx-with-let-s-encrypt-on-ubuntu-18-04) and [here](https://www.linode.com/docs/web-servers/nginx/use-nginx-reverse-proxy/) for more information), or to use Caddy. Caddy will also handle setting up letsencrypt SSL certificates for you. An example Caddyfile is included in `docker/Caddyfile`. Using Crypt without SSL __will__ result in your secrets being compromised. +```bash +# Stop container first for consistent backup +docker stop crypt +cp /path/to/crypt.db /path/to/backup/crypt.db.backup +docker start crypt +``` -_Note Caddy is only free for personal use. For commercial deployments you should build from source yourself or use Nginx._ +### PostgreSQL -## X-Frame-Options +Use standard PostgreSQL backup tools (`pg_dump`). -The nginx config included with the docker container configures the X-Frame-Options as sameorigin. This protects against a potential attacker using iframes to do bad stuff with Crypt. +### Important -Depending on your environment you may need to also configure X-Frame-Options on any proxies in front of Crypt. +Always back up your `FIELD_ENCRYPTION_KEY`. Secrets cannot be recovered without it. -## docker-compose +## Password Reset -An example `docker-compose.yml` is included. For basic usuage, you should only need to edit the `FIELD_ENCRYPTION_KEY`. +```bash +docker exec crypt ./crypt-server \ + -reset-password \ + -username=admin \ + -password='new-password' +``` diff --git a/docs/Installation_and_upgrade_on_Ubuntu_1804.md b/docs/Installation_and_upgrade_on_Ubuntu_1804.md deleted file mode 100644 index fcabc6d..0000000 --- a/docs/Installation_and_upgrade_on_Ubuntu_1804.md +++ /dev/null @@ -1,425 +0,0 @@ -Installation on Ubuntu 18.04.2 LTS -===================== -This document assumes a bare install of Ubuntu 18.04.2 LTS server. - -All commands should be run as root, unless specified - -##Install Prerequisites -###Install Apache and the Apache modules - - apt-get install apache2 libapache2-mod-wsgi-py3 - -###Install GCC (Needed for the encryption library) - - apt-get install gcc - -###Install git - - apt-get install git - -###Install the python C headers (so you can compile the encryption library) - - apt-get install python3.6-dev - -###If you want to use MySQL, you the following - - apt-get install libmysqlclient-dev - apt-get install python3-pip - pip3 install mysqlclient - -###Install the python dev tools - - apt-get install python3-setuptools - -###Verify virtual env is installed - - virtualenv --version - -###If is isn't, install it with - - apt-get install python3-venv - -##Create a non-admin service account and group -Create the Crypt user: - - useradd -d /usr/local/crypt_env cryptuser - -Create the Crypt group: - - groupadd cryptgroup - -Add cryptuser to the cryptgroup group: - - usermod -g cryptgroup cryptuser - -##Create the virtual environment -When a virtualenv is created, pip will also be installed to manage a -virtualenv's local packages. Create a virtualenv which will handle -installing Django in a contained environment. In this example we'll -create a virtualenv for Crypt at /usr/local. This should be run from -Bash, as this is what the virtualenv activate script expects. - -Go to where we're going to install the virtualenv: - - cd /usr/local - -Create the virtualenv for Crypt: - - python3 -m venv crypt_env - -Make sure cryptuser has permissions to the new crypt_env folder: - - chown -R cryptuser crypt_env - -Switch to the service account: - - su cryptuser - -Virtualenv needs to be run from a bash prompt, so let's switch to one: - - bash - -Now we can activate the virtualenv: - - cd crypt_env - source bin/activate - -##Install and configure Crypt -Still inside the crypt_env virtualenv, use git to clone the current -version of Crypt-Server - - git clone https://github.com/grahamgilbert/Crypt-Server.git crypt - -Now we need to get the other components for Crypt - - pip3 install -r crypt/setup/requirements.txt - -Now we need to generate some encryption keys (make sure these go in crypt/keyset): - - cd crypt - python3 -c "from cryptography.fernet import Fernet; print(Fernet.generate_key())" - -You will need the key output here to be set as the variable FIELD_ENCRYPTION_KEY, in the settings.py file below. -Next we need to make a copy of the example_settings.py file and put -in your info: - - cd ./fvserver - cp example_settings.py settings.py - -Edit settings.py: - -* Set FIELD_ENCRYPTION_KEY to the encryption key generated above -* Set ADMINS to an administrative name and email -* Set TIME_ZONE to the appropriate timezone -* Change ALLOWED_HOSTS to be a list of hosts that the server will be -accessible from (e.g. ``ALLOWED_HOSTS=['crypt.grahamgilbert.dev']`` - -If you wish to use email notifications, add the following to your settings.py: - -``` python -# This is the host and port you are sending email on -EMAIL_HOST = 'localhost' -EMAIL_PORT = '25' - -# If your email server requires Authentication -EMAIL_HOST_USER = 'youruser' -EMAIL_HOST_PASSWORD = 'yourpassword' -# This is the URL at the front of any links in the emails -HOST_NAME = 'http://localhost' -``` - -## Using with MySQL -In order to use Crypt-Server with MySQL, you need to configure it to connect to -a MySQL server instead of the default sqlite3. To do this, locate the DATABASES -section of settings.py, and change ENGINE to 'django.db.backends.mysql'. Set the -NAME as the database name, USER and PASSWORD to your user and password, and -either leave HOST as blank for localhost, or insert an IP or hostname of your -MySQL server. You will also need to install the correct python and apt packages. - - apt-get install libmysqlclient-dev - apt-get install pip - pip3 install mysqlclient - -## More Setup -We need to use Django's manage.py to initialise the app's database and -create an admin user. Running the syncdb command will ask you to create -an admin user - make sure you do this! - - cd .. - python3 manage.py syncdb - python3 manage.py migrate - -Stage the static files (type yes when prompted) - - python3 manage.py collectstatic - -##Installing mod_wsgi and configuring Apache -To run Crypt in a production environment, we need to set up a suitable -webserver. Make sure you exit out of the crypt_env virtualenv and the -cryptuser user (back to root) before continuing). - -##Set up an Apache virtualhost -You will probably need to edit most of these bits to suit your -environment, especially to add SSL encryption. There are many different -options, especially if you prefer nginx, the below example is for apache -with an internal puppet CA. Make a new file at -/etc/apache2/sites-available (call it whatever you want) - - vim /etc/apache2/sites-available/crypt.conf - -And then enter something like: - - - ServerName crypt.yourdomain.com - WSGIScriptAlias / /usr/local/crypt_env/crypt/crypt.wsgi - WSGIDaemonProcess cryptuser user=cryptuser group=cryptgroup - Alias /static/ /usr/local/crypt_env/crypt/static/ - SSLEngine on - SSLCertificateFile "/etc/puppetlabs/puppet/ssl/certs/cryptserver.yourdomain.com.pem" - SSLCertificateKeyFile "/etc/puppetlabs/puppet/ssl/private_keys/cryptserver.yourdomain.com.pem" - SSLCACertificatePath "/etc/puppetlabs/puppet/ssl/certs" - SSLCACertificateFile "/etc/puppetlabs/puppet/ssl/certs/ca.pem" - SSLCARevocationFile "/etc/puppetlabs/puppet/ssl/crl.pem" - SSLProtocol +TLSv1 - SSLCipherSuite ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-SHA384:ECDHE-RSA-AES128-SHA256:ECDHE-RSA-AES256-SHA:ECDHE-RSA-AES128-SHA:DHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-SHA256:DHE-RSA-AES128-SHA256:DHE-RSA-AES256-SHA:DHE-RSA-AES128-SHA - SSLHonorCipherOrder On - - WSGIProcessGroup cryptuser - WSGIApplicationGroup %{GLOBAL} - Options FollowSymLinks - AllowOverride None - Require all granted - - - WSGISocketPrefix /var/run/wsgi - WSGIPythonHome /usr/local/crypt_env - -Now we just need to enable our site, and then your can go and configure -your clients: - - a2ensite crypt.conf - service apache2 reload - - -Upgrade on Ubuntu 18.04.2 LTS from Crypt 2 to Crypt 3 -===================== -This document assumes that you have Ubuntu 18.04.2 LTS with Python 2 and non upgraded versions of Apache etc.. used for Crypt 2 installs. - -All commands should be run as root, unless specified - -##Upgrade Prerequisites -###Upgrade Apache and the Apache modules. This is critical. If you do not update apache WSGI to compile against python 3, your site will not load. - - apt-get update - apt-get upgrade - - apt-get install apache2 libapache2-mod-wsgi-py3 - -###Install GCC (Needed for the encryption library) - - apt-get update gcc - -###Install git - - apt-get update git - -###Install the python C headers (so you can compile the encryption library) - - apt-get install python3.6-dev - -###If you want to use MySQL, you the following - - apt-get update libmysqlclient-dev - apt-get install python3-pip - pip3 install mysqlclient - -###Install the python dev tools - - apt-get install python3-setuptools - -###Verify virtual env is installed - - virtualenv --version - -###If is isn't, install it with - - apt-get install python3-venv - -##Create a non-admin service account and group - -We are assuming that you already have a user if you were running Crypt 2 so if you need to create a new user refer to the above core install instructions otherwise skip. - -##Update the virtual environment -When a virtualenv is created, pip will also be installed to manage a -virtualenv's local packages. Create a virtualenv which will handle -installing Django in a contained environment. In this example we'll -create a virtualenv for Crypt at /usr/local. This should be run from -Bash, as this is what the virtualenv activate script expects. - -For the update this process simply rebuilds the virtual environment. It will -not overwrite it completely nor the files inside it. - -Go to where we're going to install the virtualenv: - - cd /usr/local - -Create the virtualenv for Crypt: - - python3 -m venv crypt_env - -Make sure cryptuser has permissions to the new crypt_env folder: - - chown -R cryptuser crypt_env - -Switch to the service account: - - su cryptuser - -Virtualenv needs to be run from a bash prompt, so let's switch to one: - - bash - -Now we can activate the virtualenv: - - cd crypt_env - source bin/activate - - -##Update and configure Crypt -Still inside the crypt_env virtualenv, use git to clone the current -version of Crypt-Server - - cd crypt - git pull - -Now we need to get the other components for Crypt - - pip3 install -r setup/requirements.txt - -Now we need to generate some encryption keys (make sure these go in crypt/keyset): - - python3 -c "from cryptography.fernet import Fernet; print(Fernet.generate_key())" - -You will need the key output here to be set as the variable FIELD_ENCRYPTION_KEY, in the settings.py file below. -Next we need to make a copy of the example_settings.py file and put -in your info: - - cd ./fvserver - nano settings.py - -There are 2 blocks that have changed - -MIDDLEWARE_SCRIPTS has changed to MIDDLEWARE and a new variable STATICFILES_STORAGE has been added underneath that block as shown below. - - MIDDLEWARE = [ - "django.middleware.security.SecurityMiddleware", - "whitenoise.middleware.WhiteNoiseMiddleware", - "django.contrib.sessions.middleware.SessionMiddleware", - "django.middleware.common.CommonMiddleware", - "django.middleware.csrf.CsrfViewMiddleware", - "django.contrib.auth.middleware.AuthenticationMiddleware", - "django.contrib.messages.middleware.MessageMiddleware", - "django.middleware.clickjacking.XFrameOptionsMiddleware", - ] - - STATICFILES_STORAGE = "whitenoise.storage.CompressedManifestStaticFilesStorage" - -TEMPLATES should be replaced fully by - - TEMPLATES = [ - { - "BACKEND": "django.template.backends.django.DjangoTemplates", - "DIRS": [ - # insert your TEMPLATE_DIRS here - os.path.join(PROJECT_DIR, "templates") - ], - "APP_DIRS": True, - "OPTIONS": { - "context_processors": [ - # Insert your TEMPLATE_CONTEXT_PROCESSORS here or use this - # list if you haven't customized them: - "django.contrib.auth.context_processors.auth", - "django.template.context_processors.debug", - "django.template.context_processors.i18n", - "django.template.context_processors.media", - "django.template.context_processors.static", - "django.template.context_processors.tz", - "django.contrib.messages.context_processors.messages", - "fvserver.context_processors.crypt_version", - ], - "debug": DEBUG, - }, - } - ] - -Finally INSTALLED_APPS has been updated as follows - - INSTALLED_APPS = ( - "whitenoise.runserver_nostatic", - "django.contrib.auth", - "django.contrib.contenttypes", - "django.contrib.sessions", - "django.contrib.sites", - "django.contrib.messages", - "django.contrib.staticfiles", - # Uncomment the next line to enable the admin: - "django.contrib.admin", - # Uncomment the next line to enable admin documentation: - "django.contrib.admindocs", - "server", - "bootstrap4", - "django_extensions", - ) - -Edit settings.py: - -* Set FIELD_ENCRYPTION_KEY to the encryption key generated above -* Set ADMINS to an administrative name and email -* Set TIME_ZONE to the appropriate timezone -* Change ALLOWED_HOSTS to be a list of hosts that the server will be -accessible from (e.g. ``ALLOWED_HOSTS=['crypt.grahamgilbert.dev']`` - -## Update the WSGI -The WSGI is hard coded with version 2.7 of Python and it needs to be modified before the WSGI will load. - - nano /usr/local/crypt_env/crypt/crypt.wsgi - -Modify the code to reflect the current version of your python 3. In this case I am using 3.6. - - site.addsitedir(os.path.join(CRYPT_ENV_DIR, 'lib/python3.6/site-packages')) - - -## Using with MySQL -In order to use Crypt-Server with MySQL, you need to configure it to connect to -a MySQL server instead of the default sqlite3. To do this, locate the DATABASES -section of settings.py, and change ENGINE to 'django.db.backends.mysql'. Set the -NAME as the database name, USER and PASSWORD to your user and password, and -either leave HOST as blank for localhost, or insert an IP or hostname of your -MySQL server. You will also need to install the correct python and apt packages. - - apt-get install libmysqlclient-dev - apt-get install pip - pip3 install mysqlclient - -## More Setup -We need to use Django's manage.py to initialise the app's database and -create an admin user. Running the syncdb command will ask you to create -an admin user - make sure you do this! - - cd .. - python3 manage.py migrate - -Stage the static files (type yes when prompted) - - python3 manage.py collectstatic - -##Installing mod_wsgi and configuring Apache -To run Crypt in a production environment, we need to set up a suitable -webserver. Make sure you exit out of the crypt_env virtualenv and the -cryptuser user (back to root) before continuing). - - service apache2 reload - - or - - service apache2 restart \ No newline at end of file diff --git a/docs/Installation_on_CentOS_7.md b/docs/Installation_on_CentOS_7.md deleted file mode 100644 index aa050bc..0000000 --- a/docs/Installation_on_CentOS_7.md +++ /dev/null @@ -1,194 +0,0 @@ -# Installation on CentOS 7 - -This document has not been updated for several years and should only be used for version 2 of Crypt server. Pull requests to update this are gratefully accepted. - -All commands should be run as root, unless specified. - -## Install Prerequisites - -### Setup and Virtual Environment - -Install needed packages: - -`yum install git python-setuptools gcc libffi-devel python-devel openssl-devel -postgresql-libs postgresql-devel` - -Check if `virtualenv` is installed via `virtualenv --version` and install it if -needed: - -`easy_install virtualenv` - -### Create a non-admin service account and group - -Create a new group: - -`groupadd cryptgroup` - -and add a new user in the cryptgroup with a home directory: - -`useradd -g cryptgroup -m cryptuser` - -### Create the virtual environment - -When a virtualenv is created, pip will also be installed to manage a -virtualenv's local packages. Create a virtualenv which will handle installing -Django in a contained environment. In this example we'll create a virtualenv for -Crypt at /usr/local. This should be run from Bash, as this is what the -virtualenv activate script expects. - -Switch to bash if needed: `/usr/bin/bash` and get into the local folder: - -`cd /usr/local` - -Create the virtialenv for Crypt `virtualenv crypt_env` and change folder -permissions: `chown -R cryptuser:cryptgroup crypt_env`. - -Switch to the newly created service account `su cryptuser` and make sure to use -the bash shell: `bash`. - -Now let's activate the virtualenv: - -``` -cd crypt_env -source bin/activate -``` - -### Copy and configure Crypt - -Still inside the crypt_env virtualenv, use git to clone the current version of -Crypt-Server: - -`git clone https://github.com/grahamgilbert/Crypt-Server.git crypt` - - -We could also get the 1.6.8 version via git without touching -the `requirements.txt`-file: `pip install git+https://github.com/django-extensions/django-extensions@243abe93451c3b53a5f562023afcd809b79c9b7f`. - -Also install these aditional packages: - -``` -pip install psycopg2==2.5.3 -pip install gunicorn -pip install setproctitle -``` - -Now we need to get the other missing components for Crypt via pip: - -`pip install -r crypt/setup/requirements.txt` - -Now we need to generate some encryption keys (dont forget to change directory!): - -``` -cd crypt -python ./generate_keyczart.py -``` - -Next we need to make a copy of the example_settings.py file and put in your -info: - -``` -cd fvserver -cp example_settings.py settings.py -vim settings.py -``` - -Atleast change the following: -- Set ADMINS to an administrative name and email -- Set TIME_ZONE to the appropriate timezone -- Change ALLOWED_HOSTS to be a list of hosts that the server will be accessible - from. -- Take a look at the `DATABASES` and email settings. - -### DB Setup - -We need to use Django's manage.py to initialise the app's database and create an -admin user. Running the syncdb command will ask you to create an admin user - -make sure you do this! - -``` -cd .. -python manage.py syncdb -python manage.py migrate -``` - -If you used an external DB like Postgres you dont need to run `pyton manage.py syncdb`. - -And stage the static files (type yes when prompted): - -``` -python manage.py collectstatic -``` - -Also create a new superuser to auth on the webinterface: - -``` -python manage.py createsuperuser --username $USERNAME -``` - -## Set up an Apache virtualhost - -Exit out of the virtualenv and also switch back to root user. After that install -the Apache Modification `mod_wsgi`: `yum install mod_wsgi`. - -Create the wsgi directory and give the cryptuser the needed rights: - -Create a new VirtualHost `vim /etc/httpd/conf.d/crypt.conf`: - -``` - - ServerName crypt.yourdomain.com - WSGIScriptAlias / /home/app/crypt_env/crypt/crypt.wsgi - WSGIDaemonProcess cryptuser user=cryptuser group=cryptgroup - Alias /static/ /home/app/crypt_env/crypt/static/ - SSLEngine on - SSLCertificateFile "/etc/puppetlabs/puppet/ssl/certs/cryptserver.yourdomain.com.pem" - SSLCertificateKeyFile "/etc/puppetlabs/puppet/ssl/private_keys/cryptserver.yourdomain.com.pem" - SSLCACertificatePath "/etc/puppetlabs/puppet/ssl/certs" - SSLCACertificateFile "/etc/puppetlabs/puppet/ssl/certs/ca.pem" - SSLCARevocationFile "/etc/puppetlabs/puppet/ssl/crl.pem" - SSLProtocol +TLSv1 - SSLCipherSuite ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-SHA384:ECDHE-RSA-AES128-SHA256:ECDHE-RSA-AES256-SHA:ECDHE-RSA-AES128-SHA:DHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-SHA256:DHE-RSA-AES128-SHA256:DHE-RSA-AES256-SHA:DHE-RSA-AES128-SHA - SSLHonorCipherOrder On - - WSGIProcessGroup cryptuser - WSGIApplicationGroup %{GLOBAL} - Require all granted - - -WSGISocketPrefix /var/run/wsgi -WSGIPythonHome /home/app/crypt_env -``` - -### Configure SELinux to work with Apache - -On CentOS SELinux is activated and needs to be configured so Apache can do it's work: - -``` -yum install -y policycoreutils-python -semanage fcontext -a -t httpd_sys_content_t "/usr/local/crypt_env/crypt(/.*)?" -semanage fcontext -a -t httpd_sys_rw_content_t "/usr/local/crypt_env/crypt(/.*)?" -setsebool -P httpd_can_sendmail on -setsebool -P httpd_can_network_connect_db on -restorecon -Rv /usr/local/crypt_env/crypt -``` - -If you enabled SSL also grant access to the key files: - -``` -semanage fcontext -a -t httpd_sys_rw_content_t "/etc/pki/tls/private/KEY.key" -restorecon -Rv /etc/pki/tls/private/KEY.key -``` - -### Open needed ports in the firewall - -``` -firewall-cmd --zone=public --add-service=https --permanent -firewall-cmd --reload -``` - -### Activate Apache and start the httpd-server - -``` -systemctl enable httpd -systemctl start httpd -``` diff --git a/docs/Installation_on_Ubuntu_1404.md b/docs/Installation_on_Ubuntu_1404.md deleted file mode 100644 index 3b7f876..0000000 --- a/docs/Installation_on_Ubuntu_1404.md +++ /dev/null @@ -1,197 +0,0 @@ -Installation on Ubuntu 14.04 LTS -===================== -This document assumes a bare install of Ubuntu 14.04 LTS server. This document has not been updated for Crypt Server 3.0. Pull requests to update this (and to update to Ubuntu 18) are gratefully accepted. - -All commands should be run as root, unless specified - -##Install Prerequisites -###Install Apache and the Apache modules - - apt-get install apache2 libapache2-mod-wsgi - -###Install GCC (Needed for the encryption library) - - apt-get install gcc - -###Install git - - apt-get install git - -###Install the python C headers (so you can compile the encryption library) - - apt-get install python-dev - -###If you want to use MySQL, you the following - - apt-get install libmysqlclient-dev python-mysqldb mysql-client - -###Install the python dev tools - - apt-get install python-setuptools - -###Verify virtual env is installed - - virtualenv --version - -###If is isn't, install it with - - easy_install virtualenv - -##Create a non-admin service account and group -Create the Crypt user: - - useradd -d /usr/local/crypt_env cryptuser - -Create the Crypt group: - - groupadd cryptgroup - -Add cryptuser to the cryptgroup group: - - usermod -g cryptgroup cryptuser - -##Create the virtual environment -When a virtualenv is created, pip will also be installed to manage a -virtualenv's local packages. Create a virtualenv which will handle -installing Django in a contained environment. In this example we'll -create a virtualenv for Crypt at /usr/local. This should be run from -Bash, as this is what the virtualenv activate script expects. - -Go to where we're going to install the virtualenv: - - cd /usr/local - -Create the virtualenv for Crypt: - - virtualenv crypt_env - -Make sure cryptuser has permissions to the new crypt_env folder: - - chown -R cryptuser crypt_env - -Switch to the service account: - - su cryptuser - -Virtualenv needs to be run from a bash prompt, so let's switch to one: - - bash - -Now we can activate the virtualenv: - - cd crypt_env - source bin/activate - -##Install and configure Crypt -Still inside the crypt_env virtualenv, use git to clone the current -version of Crypt-Server - - git clone https://github.com/grahamgilbert/Crypt-Server.git crypt - -Now we need to get the other components for Crypt - - pip install -r crypt/setup/requirements.txt - -Now we need to generate some encryption keys (make sure these go in crypt/keyset): - - cd crypt - python ./generate_keyczart.py - -Next we need to make a copy of the example_settings.py file and put -in your info: - - cd ./fvserver - cp example_settings.py settings.py - -Edit settings.py: - -* Set ADMINS to an administrative name and email -* Set TIME_ZONE to the appropriate timezone -* Change ALLOWED_HOSTS to be a list of hosts that the server will be -accessible from (e.g. ``ALLOWED_HOSTS=['crypt.grahamgilbert.dev']`` - -If you wish to use email notifications, add the following to your settings.py: - -``` python -# This is the host and port you are sending email on -EMAIL_HOST = 'localhost' -EMAIL_PORT = '25' - -# If your email server requires Authentication -EMAIL_HOST_USER = 'youruser' -EMAIL_HOST_PASSWORD = 'yourpassword' -# This is the URL at the front of any links in the emails -HOST_NAME = 'http://localhost' -``` - -## Using with MySQL -In order to use Crypt-Server with MySQL, you need to configure it to connect to -a MySQL server instead of the default sqlite3. To do this, locate the DATABASES -section of settings.py, and change ENGINE to 'django.db.backends.mysql'. Set the -NAME as the database name, USER and PASSWORD to your user and password, and -either leave HOST as blank for localhost, or insert an IP or hostname of your -MySQL server. You will also need to install the correct python and apt packages. - - apt-get install libmysqlclient-dev python-dev mysql-client - pip install mysql-python - - -## More Setup -We need to use Django's manage.py to initialise the app's database and -create an admin user. Running the syncdb command will ask you to create -an admin user - make sure you do this! - - cd .. - python manage.py syncdb - python manage.py migrate - -Stage the static files (type yes when prompted) - - python manage.py collectstatic - -##Installing mod_wsgi and configuring Apache -To run Crypt in a production environment, we need to set up a suitable -webserver. Make sure you exit out of the crypt_env virtualenv and the -cryptuser user (back to root) before continuing). - -##Set up an Apache virtualhost -You will probably need to edit most of these bits to suit your -environment, especially to add SSL encryption. There are many different -options, especially if you prefer nginx, the below example is for apache -with an internal puppet CA. Make a new file at -/etc/apache2/sites-available (call it whatever you want) - - vim /etc/apache2/sites-available/crypt.conf - -And then enter something like: - - - ServerName crypt.yourdomain.com - WSGIScriptAlias / /usr/local/crypt_env/crypt/crypt.wsgi - WSGIDaemonProcess cryptuser user=cryptuser group=cryptgroup - Alias /static/ /usr/local/crypt_env/crypt/static/ - SSLEngine on - SSLCertificateFile "/etc/puppetlabs/puppet/ssl/certs/cryptserver.yourdomain.com.pem" - SSLCertificateKeyFile "/etc/puppetlabs/puppet/ssl/private_keys/cryptserver.yourdomain.com.pem" - SSLCACertificatePath "/etc/puppetlabs/puppet/ssl/certs" - SSLCACertificateFile "/etc/puppetlabs/puppet/ssl/certs/ca.pem" - SSLCARevocationFile "/etc/puppetlabs/puppet/ssl/crl.pem" - SSLProtocol +TLSv1 - SSLCipherSuite ECDHE-RSA-AES256-GCM-SHA384:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-SHA384:ECDHE-RSA-AES128-SHA256:ECDHE-RSA-AES256-SHA:ECDHE-RSA-AES128-SHA:DHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-SHA256:DHE-RSA-AES128-SHA256:DHE-RSA-AES256-SHA:DHE-RSA-AES128-SHA - SSLHonorCipherOrder On - - WSGIProcessGroup cryptuser - WSGIApplicationGroup %{GLOBAL} - Options FollowSymLinks - AllowOverride None - Require all granted - - - WSGISocketPrefix /var/run/wsgi - WSGIPythonHome /usr/local/crypt_env - -Now we just need to enable our site, and then your can go and configure -your clients: - - a2ensite crypt.conf - service apache2 reload diff --git a/docs/migration-plan.md b/docs/migration-plan.md new file mode 100644 index 0000000..0ac03ff --- /dev/null +++ b/docs/migration-plan.md @@ -0,0 +1,141 @@ +# Migration Plan: Django -> Go Backend + New Frontend + +## Goal +Replicate the existing Django app behavior and UX exactly, while migrating to a Go backend and a new frontend (technology TBD). A migration tool will move data into a new schema with a new encryption scheme for secret fields. + +## Current App Parity Targets + +### Routes and behaviors +Source: `server/urls.py`, `server/views.py`, `fvserver/urls.py` + +- `/` home +- `/ajax/` DataTables JSON +- `/new/computer/` add computer +- `/new/secret//` add secret for computer +- `/info/secret//` secret details and request status +- `/info//` computer details +- `/request//` request secret retrieval +- `/retrieve//` retrieve secret if approved +- `/approve//` approve or deny request +- `/manage-requests/` list outstanding requests +- `/verify///` escrow verification +- `/checkin/` client escrow endpoint +- `/login/`, `/logout/`, password change routes via Django auth + +### Data model +Source: `server/models.py` + +- `Computer` + - `serial` (unique) + - `username` + - `computername` + - `last_checkin` +- `Secret` + - `computer` (FK) + - `secret` (encrypted) + - `secret_type` (`recovery_key`, `password`, `unlock_pin`) + - `date_escrowed` + - `rotation_required` +- `Request` + - `secret` (FK) + - `requesting_user` (FK -> User) + - `approved` (null/true/false) + - `auth_user` (approver) + - `reason_for_request` + - `reason_for_approval` + - `date_requested` + - `date_approved` + - `current` (bool) +- Uses Django `User`, groups, and `can_approve` permission + +### Workflow behavior +Source: `server/views.py`, `fvserver/system_settings.py` + +- Request approval and denial flow; pending/approved/denied states. +- Self-approval optional gating (`APPROVE_OWN`). +- Global approver permission option (`ALL_APPROVE`). +- Cleanup: requests older than 7 days after approval are set `current=false`. +- Secret rotation signaling on retrieval (`ROTATE_VIEWED_SECRETS`). +- Emails for requests and approvals (if `SEND_EMAIL`). +- `HOST_NAME` used for link generation. + +### UI/UX parity +Source: `server/templates/server/*.html`, `templates/*.html`, `site_static/*` + +- Server-rendered pages using Bootstrap + DataTables. +- DataTables search/sort/pagination for the home list. +- Request, approve, retrieve, and manage screens. +- Login + password change flows (for local users). +- CSRF protections on all user input +- UI that allows admin users to create, edit, delete users and reset passwords. +- Utility in the main app binary to create the first admin user +- UI should look exactly the same as the existing django app +- For SAML users, isStaff or can approve permissions should be able to be set via saml attributes + +## Migration Plan + +### 1) Parity Spec and Contract Definition +- Enumerate and document every endpoint, request payload, response shape, and status code. +- Capture UI flows + required forms from templates. +- Freeze feature flags and configuration semantics: + - `APPROVE_OWN`, `ALL_APPROVE`, `ROTATE_VIEWED_SECRETS` + - Email behavior (`SEND_EMAIL`, `EMAIL_*`) + - `HOST_NAME` + +### 2) Go Backend Architecture (Design Only) +- **Auth**: prefer SAML-first with local reset (avoid importing Django hashes). + - Data flags: `local_login_enabled` (tenant/user), `must_reset_password`, `password_hash` nullable, optional `auth_source` (`saml`, `local`, `hybrid`). + - Migration: import identity fields only; set `password_hash=null`, `must_reset_password=true`; default `local_login_enabled=false` for SAML-only tenants. + - UX: show SAML by default; local login only if enabled; if local login and `must_reset_password=true`, force reset flow before password auth. + - Reset flow: admin-only resets set password + clear `must_reset_password`; no email sent; enforce strong password policy. + - Admin actions: log admin-driven password resets and forced reset toggles (who/when/target user/IP). + - Provide an admin UI to view these audit logs. + - Add `must_reset_password` support for next-login resets (used during local-user migration). +- **Permissions**: replicate `can_approve` semantics and group assignment logic. +- **Endpoints**: provide exact behavior parity (including `checkin`/`verify` JSON). +- **Cleanup job**: scheduled cleanup for expired requests. +- **Email notifications**: remove entirely (no request/approval emails). +- **Audit/logging**: maintain approval/request metadata and event logs. + - Admin audit events: password resets and forced reset toggles with fields `actor`, `timestamp`, `target_user`, `ip_address`, `action`, and `reason` (if provided). + - Admin UI: add an audit log view for these events (read-only). + +### 3) Data Model Mapping + Migration Tool +- **Extract**: dump from Django DB; decrypt `Secret.secret` with existing `FIELD_ENCRYPTION_KEY`. +- **Transform**: map old schema to new schema, preserving IDs and foreign keys where possible. +- **Re-encrypt**: apply new encryption scheme for secret fields. +- **Users and permissions**: migrate users, group membership, and `can_approve`. + - Optional: migration tool accepts a password mapping file to set initial local passwords (when provided). + - Proposed format: CSV with `username_or_email,password,must_reset_password` (last column optional). + - Behavior: if an entry exists, set the password; `must_reset_password` defaults to `false` for mapped entries unless explicitly set. + - Users not in the file default to `must_reset_password=true`. +- **Validation**: + - Record counts by table. + - Referential integrity checks. + - Spot-check decrypt -> re-encrypt correctness. + - Validate `/verify/` and `/checkin/` behavior on migrated data. + + + Previous requqests and approvals should also be migrated. + +### 4) Frontend Plan (Tech TBD) +- Option A: **Server-rendered HTML** with Go templates, reusing existing HTML structure and CSS/JS assets for perfect parity. + +- Ensure that all approval/pending/retrieve states match current UI semantics. + +### 5) Cutover and Rollback Strategy +- **Phase 1**: Freeze or dual-write, export data. +- **Phase 2**: Staging validation against the parity spec. +- **Phase 3**: Production cutover with snapshot and rollback plan. + +## Open Questions / Decisions Needed +- Frontend direction (server-rendered vs SPA)? + - server rendered +- Auth strategy (import Django hashes vs reset or SSO)? + - answered above +- Target database for Go (Postgres vs other)? + - postgres + - sqlite +- Email delivery provider requirements? + - scrap email completely +- Any changes desired in request cleanup timing or rotation semantics? + - keep as-is diff --git a/docs/saml-config.sample.yaml b/docs/saml-config.sample.yaml new file mode 100644 index 0000000..f557bf5 --- /dev/null +++ b/docs/saml-config.sample.yaml @@ -0,0 +1,74 @@ +# SAML Configuration for Crypt Server +# +# Required fields: +# - root_url: The public URL where the app is accessible +# - idp_metadata_path OR idp_metadata_url: Path to IdP metadata XML file or URL to fetch it +# +# Certificate/key are optional unless sign_request is true or your IdP encrypts assertions + +# Required: The public URL where the application is accessible +root_url: https://crypt.example.com + +# Optional: Custom entity ID (defaults to root_url + metadata_url_path) +entity_id: https://crypt.example.com/saml2/metadata/ + +# IdP Metadata - provide ONE of these: +idp_metadata_path: /path/to/idp/metadata.xml +# idp_metadata_url: https://your-idp.example.com/metadata + +# Optional: SP certificate and private key +# Only required if sign_request is true or your IdP encrypts SAML assertions +# Most IdPs (including Okta) don't require these for basic setup +# certificate_path: /path/to/sp.crt +# private_key_path: /path/to/sp.key + +# Optional: Sign SAML authentication requests (requires certificate_path and private_key_path) +sign_request: false + +# Optional: Allow IdP-initiated SSO (user clicks app in IdP portal) +allow_idp_initiated: true + +# Username extraction - tries these in order: +# 1. NameID (if use_name_id_as_username is true) +# 2. username_attribute +# 3. attribute_mapping with value "username" +# 4. "uid" attribute +use_name_id_as_username: true +username_attribute: uid + +# Optional: Map SAML attributes to user fields +attribute_mapping: + uid: username + mail: email + +# Group-based permissions +# The attribute containing group membership (default: memberOf) +groups_attribute: memberOf + +# Groups that grant staff (admin UI access) permissions +staff_groups: + - Crypt-Admins + +# Groups that grant superuser permissions (staff + approve) +superuser_groups: [] + +# Groups that grant approval permissions +can_approve_groups: + - Crypt-Approvers + - Crypt-Admins + +# Auto-create users on first SAML login +create_unknown_user: true + +# Default settings for SAML-created users +auth_source: saml +local_login_enabled: false +must_reset_password: false + +# Redirect after successful login +default_redirect_uri: / + +# SAML endpoint paths (defaults shown) +metadata_url_path: /saml2/metadata/ +acs_url_path: /saml2/acs/ +slo_url_path: /saml2/ls/ diff --git a/functional_tests/__init__.py b/functional_tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/functional_tests/base.py b/functional_tests/base.py deleted file mode 100644 index 3867fd2..0000000 --- a/functional_tests/base.py +++ /dev/null @@ -1,34 +0,0 @@ -from django.contrib.staticfiles.testing import StaticLiveServerTestCase -from django.contrib.auth.models import User -from selenium import webdriver -from server.models import Computer, Secret -from datetime import datetime - - -class FunctionalTest(StaticLiveServerTestCase): - def setUp(self): - self.browser = webdriver.Firefox() - User.objects.create_superuser("admin", "a@a.com", "sekrit") - User.objects._create_user( - "tech", "t@a.com", "password", is_staff=True, is_superuser=False - ) - tech_test_computer = Computer(serial="TECHSERIAL") - tech_test_computer.username = "Daft Tech" - tech_test_computer.computername = "compy587" - tech_test_computer.save() - secret = Secret( - computer=tech_test_computer, - secret="SHHH-DONT-TELL", - date_escrowed=datetime.now(), - ) - secret.save() - - def tearDown(self): - self.browser.quit() - - # currently doesn't work to find entered elements - # def check_for_row_in_list_table(self, row_value): - # table = self.browser.find_element_by_id('id_list_table') - # rows = table.find_elements_by_tag_name('tr') - # value = rows.find_elements_by_tag_name('td') - # self.assertIn(row_value, [value.text for row in rows]) diff --git a/functional_tests/test_simple_site_functionality.py b/functional_tests/test_simple_site_functionality.py deleted file mode 100644 index e7470d5..0000000 --- a/functional_tests/test_simple_site_functionality.py +++ /dev/null @@ -1,87 +0,0 @@ -import time -from .base import FunctionalTest -from selenium.webdriver.common.keys import Keys - - -class LoginAndBasicFunctionality(FunctionalTest): - def test_admin_can_create_and_browse(self): - # Admin goes to fv2 key mgmt site, sees it's named Crypt post-redirect to a login - self.browser.get(self.live_server_url) - self.assertIn("Crypt", self.browser.title) - username_box = self.browser.find_element_by_id("id_username") - password_box = self.browser.find_element_by_id("id_password") - username_box.send_keys("admin") - password_box.send_keys("sekrit") - password_box.send_keys(Keys.ENTER) - time.sleep(1) - # After putting in creds, admin can create a computer from the hamburger menu, and is redirected to details - self.browser.find_element_by_id("dLabel").click() - self.browser.find_element_by_link_text("New computer").click() - inputbox = self.browser.find_element_by_id("id_serial") - self.assertEqual(inputbox.get_attribute("placeholder"), "Serial Number") - inputbox.send_keys("MYSERIAL") - username = self.browser.find_element_by_id("id_username") - username.send_keys("Mr. Admin") - computername = self.browser.find_element_by_id("id_computername") - computername.send_keys("compy486") - self.browser.find_element_by_css_selector("button.btn.btn-primary").click() - detail_url = self.browser.current_url - self.assertRegexpMatches(detail_url, "/info/.+") - # When viewing details of computer, admin can create a secret for it - self.browser.find_element_by_class_name("dropdown-toggle").click() - self.browser.find_element_by_css_selector( - "span.glyphicon.glyphicon-plus" - ).click() - secretbox = self.browser.find_element_by_name("secret") - self.assertEqual(secretbox.get_attribute("placeholder"), "Secret") - secretbox.send_keys("LICE-NSEP-LATE") - self.browser.find_element_by_css_selector("button.btn.btn-primary").click() - # The newly created secret shows up on the page, and you can click info - self.browser.find_element_by_css_selector("a.btn.btn-info.btn-xs").click() - # You're taken to the secret's info page, and you can start a request and provide a reason - self.browser.find_element_by_css_selector("a.btn.btn-large.btn-info").click() - requestbox = self.browser.find_element_by_name("reason_for_request") - self.assertEqual(requestbox.get_attribute("placeholder"), "Reason for request") - requestbox.send_keys("Pretty Please Gimme") - self.browser.find_element_by_css_selector( - "button.btn.primary.btn-default" - ).click() - # As the admin is all-powerful, they are automatically approved and can find the secret in the page text - key = self.browser.find_element_by_tag_name("code").text - self.assertEqual(key, "LICE-NSEP-LATE") - - def test_standard_user_can_request_and_admin_can_approve(self): - # Standard tech user can log in and finds previously-created computer+secret - self.browser.get(self.live_server_url) - self.assertIn("Crypt", self.browser.title) - username_box = self.browser.find_element_by_id("id_username") - password_box = self.browser.find_element_by_id("id_password") - username_box.send_keys("tech") - password_box.send_keys("password") - password_box.send_keys(Keys.ENTER) - self.browser.find_element_by_link_text("Info").click() - self.browser.find_element_by_link_text("Info / Request").click() - secret_url = self.browser.current_url - self.assertRegexpMatches(secret_url, "/info/secret/.+") - self.browser.find_element_by_link_text("Request Key").click() - requestbox = self.browser.find_element_by_name("reason_for_request") - requestbox.send_keys("With sugar on top") - self.browser.find_element_by_css_selector( - "button.btn.primary.btn-default" - ).click() - # Standard users live in a world ruled by gravity, and must wait for approval - disabled_button = self.browser.find_element_by_css_selector( - "button.btn.btn-disabled.btn-info" - ).text - self.assertEqual(disabled_button, "Request Pending") - # Let's log out and let the admin do their approval magic - self.browser.find_element_by_id("dLabel").click() - self.browser.find_element_by_link_text("Log out").click() - username_box = self.browser.find_element_by_id("id_username") - password_box = self.browser.find_element_by_id("id_password") - username_box.send_keys("admin") - password_box.send_keys("sekrit") - password_box.send_keys(Keys.ENTER) - self.browser.find_element_by_link_text("Approve requests").click() - # This should fail, as per https://github.com/grahamgilbert/Crypt-Server/issues/12 - self.browser.find_element_by_link_text("Manage").click() diff --git a/fvserver/__init__.py b/fvserver/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/fvserver/context_processors.py b/fvserver/context_processors.py deleted file mode 100644 index ddfca9c..0000000 --- a/fvserver/context_processors.py +++ /dev/null @@ -1,12 +0,0 @@ -import plistlib -import os - - -def crypt_version(request): - # return the value you want as a dictionary. you may add multiple values in there. - current_dir = os.path.dirname(os.path.realpath(__file__)) - with open( - os.path.join(os.path.dirname(current_dir), "fvserver", "version.plist"), "rb" - ) as f: - version = plistlib.load(f) - return {"CRYPT_VERSION": version["version"]} diff --git a/fvserver/example_settings.py b/fvserver/example_settings.py deleted file mode 100755 index f38f599..0000000 --- a/fvserver/example_settings.py +++ /dev/null @@ -1,15 +0,0 @@ -import os -from fvserver.system_settings import * - -DATABASES = { - "default": { - "ENGINE": "django.db.backends.sqlite3", # Add 'postgresql_psycopg2', 'mysql', 'sqlite3' or 'oracle'. - "NAME": os.path.join( - PROJECT_DIR, "crypt.db" - ), # Or path to database file if using sqlite3. - "USER": "", # Not used with sqlite3. - "PASSWORD": "", # Not used with sqlite3. - "HOST": "", # Set to empty string for localhost. Not used with sqlite3. - "PORT": "", # Set to empty string for default. Not used with sqlite3. - } -} diff --git a/fvserver/system_settings.py b/fvserver/system_settings.py deleted file mode 100644 index 0acc033..0000000 --- a/fvserver/system_settings.py +++ /dev/null @@ -1,186 +0,0 @@ -import os - -# Django settings for fvserver project. - -PROJECT_DIR = os.path.abspath( - os.path.join(os.path.dirname(os.path.abspath(__file__)), os.path.pardir) -) -ENCRYPTED_FIELD_KEYS_DIR = os.path.join(PROJECT_DIR, "keyset") -DEBUG = False - -ROTATE_VIEWED_SECRETS = True - -DATE_FORMAT = "Y-m-d H:i:s" -DATETIME_FORMAT = "Y-m-d H:i:s" - -ADMINS = [ - ( - # ('Your Name', 'your_email@example.com'), - ) -] - -FIELD_ENCRYPTION_KEY = os.environ.get("FIELD_ENCRYPTION_KEY", "") - -MANAGERS = ADMINS - -# Local time zone for this installation. Choices can be found here: -# http://en.wikipedia.org/wiki/List_of_tz_zones_by_name -# although not all choices may be available on all operating systems. -# In a Windows environment this must be set to your system time zone. -TIME_ZONE = "Europe/London" - -# Language code for this installation. All choices can be found here: -# http://www.i18nguy.com/unicode/language-identifiers.html -LANGUAGE_CODE = "en-us" - -SITE_ID = 1 - -# If you set this to False, Django will make some optimizations so as not -# to load the internationalization machinery. -USE_I18N = True - -# If you set this to False, Django will not format dates, numbers and -# calendars according to the current locale. -USE_L10N = False - -# If you set this to False, Django will not use timezone-aware datetimes. -USE_TZ = True - -# Absolute filesystem path to the directory that will hold user-uploaded files. -# Example: "/home/media/media.lawrence.com/media/" -MEDIA_ROOT = "" - -# URL that handles the media served from MEDIA_ROOT. Make sure to use a -# trailing slash. -# Examples: "http://media.lawrence.com/media/", "http://example.com/media/" -MEDIA_URL = "" - -# Absolute path to the directory static files should be collected to. -# Don't put anything in this directory yourself; store your static files -# in apps' "static/" subdirectories and in STATICFILES_DIRS. -# Example: "/home/media/media.lawrence.com/static/" -STATIC_ROOT = os.path.join(PROJECT_DIR, "static") - -# URL prefix for static files. -# Example: "http://media.lawrence.com/static/" -STATIC_URL = "/static/" - -# URL prefix for admin static files -- CSS, JavaScript and images. -# Make sure to use a trailing slash. -# Examples: "http://foo.com/static/admin/", "/static/admin/". -# deprecated in Django 1.4, but django_wsgiserver still looks for it -# when serving admin media -ADMIN_MEDIA_PREFIX = "/static_admin/" - -# Additional locations of static files -STATICFILES_DIRS = ( - # Put strings here, like "/home/html/static" or "C:/www/django/static". - # Always use forward slashes, even on Windows. - # Don't forget to use absolute paths, not relative paths. - os.path.join(PROJECT_DIR, "site_static"), -) - -LOGIN_URL = "/login/" -LOGIN_REDIRECT_URL = "/" - -ALLOWED_HOSTS = ["*"] - -# List of finder classes that know how to find static files in -# various locations. -STATICFILES_FINDERS = ( - "django.contrib.staticfiles.finders.FileSystemFinder", - "django.contrib.staticfiles.finders.AppDirectoriesFinder", - # 'django.contrib.staticfiles.finders.DefaultStorageFinder', -) - -# Make this unique, and don't share it with anybody. -SECRET_KEY = "6%y8=x5(#ufxd*+d+-ohwy0b$5z^cla@7tvl@n55_h_cex0qat" - -TEMPLATES = [ - { - "BACKEND": "django.template.backends.django.DjangoTemplates", - "DIRS": [ - # insert your TEMPLATE_DIRS here - os.path.join(PROJECT_DIR, "templates") - ], - "APP_DIRS": True, - "OPTIONS": { - "context_processors": [ - # Insert your TEMPLATE_CONTEXT_PROCESSORS here or use this - # list if you haven't customized them: - "django.contrib.auth.context_processors.auth", - "django.template.context_processors.debug", - "django.template.context_processors.i18n", - "django.contrib.messages.context_processors.messages", - "django.template.context_processors.media", - "django.template.context_processors.static", - "django.template.context_processors.tz", - "django.template.context_processors.request", - "fvserver.context_processors.crypt_version", - ], - "debug": DEBUG, - }, - } -] - -MIDDLEWARE = [ - "django.middleware.security.SecurityMiddleware", - "whitenoise.middleware.WhiteNoiseMiddleware", - "django.contrib.sessions.middleware.SessionMiddleware", - "django.middleware.common.CommonMiddleware", - "django.middleware.csrf.CsrfViewMiddleware", - "django.contrib.auth.middleware.AuthenticationMiddleware", - "django.contrib.messages.middleware.MessageMiddleware", - "django.middleware.clickjacking.XFrameOptionsMiddleware", -] - - -ROOT_URLCONF = "fvserver.urls" - -# Python dotted path to the WSGI application used by Django's runserver. -WSGI_APPLICATION = "fvserver.wsgi.application" - - -INSTALLED_APPS = ( - "whitenoise.runserver_nostatic", - "django.contrib.auth", - "django.contrib.contenttypes", - "django.contrib.sessions", - "django.contrib.sites", - "django.contrib.messages", - "django.contrib.staticfiles", - # Uncomment the next line to enable the admin: - "django.contrib.admin", - # Uncomment the next line to enable admin documentation: - "django.contrib.admindocs", - "server", - "bootstrap4", - "django_extensions", -) - -LOGGING = { - "version": 1, - "disable_existing_loggers": False, - "formatters": { - "default": { - "format": "[DJANGO] %(levelname)s %(asctime)s %(module)s " - "%(name)s.%(funcName)s:%(lineno)s: %(message)s" - }, - }, - "handlers": { - "console": { - "level": "DEBUG", - "class": "logging.StreamHandler", - "formatter": "default", - } - }, - "loggers": { - "*": { - "handlers": ["console"], - "level": "DEBUG", - "propagate": True, - } - }, -} - -DEFAULT_AUTO_FIELD = "django.db.models.AutoField" diff --git a/fvserver/urls.py b/fvserver/urls.py deleted file mode 100644 index 8eef191..0000000 --- a/fvserver/urls.py +++ /dev/null @@ -1,31 +0,0 @@ -# from django.conf.urls import include, url - -# Uncomment the next two lines to enable the admin: -from django.contrib import admin - -# admin.autodiscover() -import django.contrib.auth.views as auth_views -import django.contrib.admindocs.urls as admindocs_urls -from django.urls import path, include - -app_name = "fvserver" - -urlpatterns = [ - path("login/", auth_views.LoginView.as_view(), name="login"), - path("logout/", auth_views.logout_then_login, name="logout"), - path( - "changepassword/", - auth_views.PasswordChangeView.as_view(), - name="password_change", - ), - path( - "changepassword/done/", - auth_views.PasswordChangeDoneView.as_view(), - name="password_change_done", - ), - path("", include("server.urls")), - # Uncomment the admin/doc line below to enable admin documentation: - path("admin/doc/", include(admindocs_urls)), - # Uncomment the next line to enable the admin: - path("admin/", admin.site.urls), -] diff --git a/fvserver/version.plist b/fvserver/version.plist deleted file mode 100644 index d03c1d7..0000000 --- a/fvserver/version.plist +++ /dev/null @@ -1,8 +0,0 @@ - - - - - version - 3.4.1.378 - - diff --git a/fvserver/wsgi.py b/fvserver/wsgi.py deleted file mode 100644 index 1924e61..0000000 --- a/fvserver/wsgi.py +++ /dev/null @@ -1,30 +0,0 @@ -""" -WSGI config for fvserver project. - -This module contains the WSGI application used by Django's development server -and any production WSGI deployments. It should expose a module-level variable -named ``application``. Django's ``runserver`` and ``runfcgi`` commands discover -this application via the ``WSGI_APPLICATION`` setting. - -Usually you will have the standard Django WSGI application here, but it also -might make sense to replace the whole Django WSGI application with a custom one -that later delegates to the Django one. For example, you could introduce WSGI -middleware here, or combine a Django application with an application of another -framework. - -""" - -import os - -os.environ.setdefault("DJANGO_SETTINGS_MODULE", "fvserver.settings") - -# This application object is used by any WSGI server configured to use this -# file. This includes Django's development server, if the WSGI_APPLICATION -# setting points here. -from django.core.wsgi import get_wsgi_application - -application = get_wsgi_application() - -# Apply WSGI middleware here. -# from helloworld.wsgi import HelloWorldApplication -# application = HelloWorldApplication(application) diff --git a/generate_keyczart.py b/generate_keyczart.py deleted file mode 100644 index b8b022e..0000000 --- a/generate_keyczart.py +++ /dev/null @@ -1,17 +0,0 @@ -import keyczar -import subprocess -import os - -directory = os.path.join(os.path.dirname(os.path.abspath(__file__)), "keyset") - -if not os.path.exists(directory): - os.makedirs(directory) - -if not os.listdir(directory): - location_string = "--location={}".format(directory) - cmd = ["keyczart", "create", location_string, "--purpose=crypt", "--name=crypt"] - subprocess.check_call(cmd) - cmd = ["keyczart", "addkey", location_string, "--status=primary"] - subprocess.check_call(cmd) -else: - print("Keyset directory already has something in there. Skipping key generation.") diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f6de7d9 --- /dev/null +++ b/go.mod @@ -0,0 +1,40 @@ +module crypt-server + +go 1.21 + +require ( + github.com/DATA-DOG/go-sqlmock v1.5.2 + github.com/crewjam/saml v0.4.13 + github.com/fernet/fernet-go v0.0.0-20240119011108-303da6aec611 + github.com/golang-jwt/jwt/v4 v4.4.3 + github.com/lib/pq v1.10.9 + github.com/mattn/go-sqlite3 v1.14.33 + github.com/stretchr/testify v1.9.0 + golang.org/x/crypto v0.23.0 + gopkg.in/yaml.v3 v3.0.1 + modernc.org/sqlite v1.30.0 +) + +require ( + github.com/beevik/etree v1.1.0 // indirect + github.com/crewjam/httperr v0.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect + github.com/jonboulle/clockwork v0.2.2 // indirect + github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v0.1.9 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/russellhaering/goxmldsig v1.2.0 // indirect + golang.org/x/sys v0.20.0 // indirect + modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect + modernc.org/libc v1.50.9 // indirect + modernc.org/mathutil v1.6.0 // indirect + modernc.org/memory v1.8.0 // indirect + modernc.org/strutil v1.2.0 // indirect + modernc.org/token v1.1.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6fb3a2f --- /dev/null +++ b/go.sum @@ -0,0 +1,129 @@ +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/beevik/etree v1.1.0 h1:T0xke/WvNtMoCqgzPhkX2r4rjY3GDZFi+FjpRZY2Jbs= +github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/crewjam/httperr v0.2.0 h1:b2BfXR8U3AlIHwNeFFvZ+BV1LFvKLlzMjzaTnZMybNo= +github.com/crewjam/httperr v0.2.0/go.mod h1:Jlz+Sg/XqBQhyMjdDiC+GNNRzZTD7x39Gu3pglZ5oH4= +github.com/crewjam/saml v0.4.13 h1:TYHggH/hwP7eArqiXSJUvtOPNzQDyQ7vwmwEqlFWhMc= +github.com/crewjam/saml v0.4.13/go.mod h1:igEejV+fihTIlHXYP8zOec3V5A8y3lws5bQBFsTm4gA= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/fernet/fernet-go v0.0.0-20240119011108-303da6aec611 h1:JwYtKJ/DVEoIA5dH45OEU7uoryZY/gjd/BQiwwAOImM= +github.com/fernet/fernet-go v0.0.0-20240119011108-303da6aec611/go.mod h1:zHMNeYgqrTpKyjawjitDg0Osd1P/FmeA0SZLYK3RfLQ= +github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU= +github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= +github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= +github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= +github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= +github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/russellhaering/goxmldsig v1.2.0 h1:Y6GTTc9Un5hCxSzVz4UIWQ/zuVwDvzJk80guqzwx6Vg= +github.com/russellhaering/goxmldsig v1.2.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/zenazn/goji v1.0.1/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +golang.org/x/crypto v0.0.0-20220128200615-198e4374d7ed/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= +golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= +golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= +gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= +modernc.org/cc/v4 v4.21.2 h1:dycHFB/jDc3IyacKipCNSDrjIC0Lm1hyoWOZTRR20Lk= +modernc.org/cc/v4 v4.21.2/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ= +modernc.org/ccgo/v4 v4.17.8 h1:yyWBf2ipA0Y9GGz/MmCmi3EFpKgeS7ICrAFes+suEbs= +modernc.org/ccgo/v4 v4.17.8/go.mod h1:buJnJ6Fn0tyAdP/dqePbrrvLyr6qslFfTbFrCuaYvtA= +modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE= +modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ= +modernc.org/gc/v2 v2.4.1 h1:9cNzOqPyMJBvrUipmynX0ZohMhcxPtMccYgGOJdOiBw= +modernc.org/gc/v2 v2.4.1/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU= +modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI= +modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4= +modernc.org/libc v1.50.9 h1:hIWf1uz55lorXQhfoEoezdUHjxzuO6ceshET/yWjSjk= +modernc.org/libc v1.50.9/go.mod h1:15P6ublJ9FJR8YQCGy8DeQ2Uwur7iW9Hserr/T3OFZE= +modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= +modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= +modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E= +modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU= +modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= +modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc= +modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss= +modernc.org/sqlite v1.30.0 h1:8YhPUs/HTnlEgErn/jSYQTwHN/ex8CjHHjg+K9iG7LM= +modernc.org/sqlite v1.30.0/go.mod h1:cgkTARJ9ugeXSNaLBPK3CqbOe7Ec7ZhWPoMFGldEYEw= +modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= +modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/internal/app/admin_ui_test.go b/internal/app/admin_ui_test.go new file mode 100644 index 0000000..6597f12 --- /dev/null +++ b/internal/app/admin_ui_test.go @@ -0,0 +1,243 @@ +package app + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAdminUserListUI(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + _, err := memStore.AddUser("testuser", passwordHash, false, true, true, false, "local") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/admin/users/", nil, "admin") + serveProtected(server, rec, req, server.handleUserList) + + require.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + require.Contains(t, body, "Users") + require.Contains(t, body, "New User") + require.Contains(t, body, "admin") + require.Contains(t, body, "testuser") + require.Contains(t, body, "/admin/users/new/") +} + +func TestAdminNewUserUI(t *testing.T) { + server, _, sessionManager := newTestServer(t) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/admin/users/new/", nil, "admin") + serveProtected(server, rec, req, server.handleNewUser) + + require.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + require.Contains(t, body, "New User") + require.Contains(t, body, "Username") + require.Contains(t, body, "Password") + require.Contains(t, body, "Enable local login") + require.Contains(t, body, "Admin user") + require.Contains(t, body, "Can approve requests") + require.Contains(t, body, "Auth source") +} + +func TestAdminEditUserUI(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + user, err := memStore.AddUser("edituser", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/admin/users/"+intToString(user.ID)+"/edit/", nil, "admin") + serveProtected(server, rec, req, server.handleUserEdit) + + require.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + require.Contains(t, body, "Edit User") + require.Contains(t, body, "edituser") + require.Contains(t, body, "Admin user") + require.Contains(t, body, "Can approve requests") + require.Contains(t, body, "Local login enabled") + require.Contains(t, body, "Save Changes") + require.Contains(t, body, "Back") +} + +func TestAdminResetPasswordUI(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + user, err := memStore.AddUser("resetuser", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/admin/users/"+intToString(user.ID)+"/password/", nil, "admin") + serveProtected(server, rec, req, server.handleUserPassword) + + require.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + require.Contains(t, body, "Reset Password") + require.Contains(t, body, "resetuser") + require.Contains(t, body, "New password") + require.Contains(t, body, "Reset Password") +} + +func TestAdminDeleteUserUI(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + user, err := memStore.AddUser("deleteuser", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/admin/users/"+intToString(user.ID)+"/delete/", nil, "admin") + serveProtected(server, rec, req, server.handleUserDelete) + + require.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + require.Contains(t, body, "Delete User") + require.Contains(t, body, "deleteuser") + require.Contains(t, body, "Are you sure") + require.Contains(t, body, "Delete User") + require.Contains(t, body, "Cancel") +} + +func TestAdminAuditLogUI(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + _, err := memStore.AddAuditEvent("admin", "testuser", "user_created", "test reason", "127.0.0.1") + require.NoError(t, err) + _, err = memStore.AddAuditEvent("admin", "testuser2", "password_reset", "", "192.168.1.1") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/admin/audit/", nil, "admin") + serveProtected(server, rec, req, server.handleAuditLog) + + require.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + require.Contains(t, body, "Audit Log") + require.Contains(t, body, "Search audit log") + require.Contains(t, body, "admin") + require.Contains(t, body, "testuser") + require.Contains(t, body, "user_created") + require.Contains(t, body, "password_reset") + require.Contains(t, body, "127.0.0.1") + require.Contains(t, body, "192.168.1.1") +} + +func TestAdminAuditLogSearch(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + _, err := memStore.AddAuditEvent("admin", "user1", "user_created", "", "127.0.0.1") + require.NoError(t, err) + _, err = memStore.AddAuditEvent("admin", "user2", "password_reset", "", "127.0.0.1") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/admin/audit/?q=password", nil, "admin") + serveProtected(server, rec, req, server.handleAuditLog) + + require.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + require.Contains(t, body, "password_reset") + require.Contains(t, body, "Clear") +} + +func TestAdminAuditLogPagination(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + // Create more than 50 events to test pagination + for i := 0; i < 55; i++ { + _, err := memStore.AddAuditEvent("admin", "user", "user_created", "", "127.0.0.1") + require.NoError(t, err) + } + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/admin/audit/", nil, "admin") + serveProtected(server, rec, req, server.handleAuditLog) + + require.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + require.Contains(t, body, "Page 1 of 2") + require.Contains(t, body, "Next") +} + +func TestAdminUINavigation(t *testing.T) { + server, _, sessionManager := newTestServer(t) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/", nil, "admin") + serveProtected(server, rec, req, server.handleIndex) + + require.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + require.Contains(t, body, "/admin/users/") + require.Contains(t, body, "/admin/audit/") + require.Contains(t, body, "Users") + require.Contains(t, body, "Audit Log") +} + +func TestAdminUIHiddenForNonStaffUsers(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + _, err := memStore.AddUser("regularuser", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/", nil, "regularuser") + serveProtected(server, rec, req, server.handleIndex) + + require.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + require.NotContains(t, body, "/admin/users/") + require.NotContains(t, body, "Audit Log") +} + +func TestCreateUserWithAllPermissions(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + + form := url.Values{} + form.Set("username", "fulluser") + form.Set("password", "Str0ng!Passw0rd") + form.Set("local_login_enabled", "on") + form.Set("is_staff", "on") + form.Set("can_approve", "on") + form.Set("auth_source", "local") + + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/admin/users/new/", form, "admin") + serveProtected(server, rec, req, server.handleNewUser) + + require.Equal(t, http.StatusSeeOther, rec.Code) + + user, err := memStore.GetUserByUsername("fulluser") + require.NoError(t, err) + require.True(t, user.IsStaff) + require.True(t, user.CanApprove) + require.True(t, user.LocalLoginEnabled) + require.False(t, user.MustResetPassword) + require.Equal(t, "local", user.AuthSource) +} + +func TestCreateSAMLOnlyUser(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + + form := url.Values{} + form.Set("username", "samluser") + form.Set("password", "") + form.Set("is_staff", "on") + form.Set("auth_source", "saml") + + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/admin/users/new/", form, "admin") + serveProtected(server, rec, req, server.handleNewUser) + + require.Equal(t, http.StatusSeeOther, rec.Code) + + user, err := memStore.GetUserByUsername("samluser") + require.NoError(t, err) + require.True(t, user.IsStaff) + require.False(t, user.LocalLoginEnabled) + require.Equal(t, "saml", user.AuthSource) + require.Equal(t, "", user.PasswordHash) +} diff --git a/internal/app/cleanup.go b/internal/app/cleanup.go new file mode 100644 index 0000000..27538e8 --- /dev/null +++ b/internal/app/cleanup.go @@ -0,0 +1,29 @@ +package app + +import "time" + +const requestCleanupAfterApproval = 7 * 24 * time.Hour + +func (s *Server) startRequestCleanupJob() { + if s.settings.RequestCleanupInterval <= 0 { + return + } + ticker := time.NewTicker(s.settings.RequestCleanupInterval) + go func() { + for range ticker.C { + s.cleanupOldRequests() + } + }() +} + +func (s *Server) cleanupOldRequests() { + cutoff := time.Now().Add(-requestCleanupAfterApproval) + updated, err := s.store.CleanupRequests(cutoff) + if err != nil { + s.logger.Printf("cleanup requests failed: %v", err) + return + } + if updated > 0 { + s.logger.Printf("cleanup requests updated %d rows", updated) + } +} diff --git a/internal/app/cleanup_test.go b/internal/app/cleanup_test.go new file mode 100644 index 0000000..f1c8529 --- /dev/null +++ b/internal/app/cleanup_test.go @@ -0,0 +1,41 @@ +package app + +import ( + "io" + "log" + "testing" + "time" + + "crypt-server/internal/store" + "github.com/stretchr/testify/require" +) + +type cleanupStoreStub struct { + store.Store + called bool + cutoff time.Time +} + +func (s *cleanupStoreStub) CleanupRequests(approvedBefore time.Time) (int, error) { + s.called = true + s.cutoff = approvedBefore + return 0, nil +} + +func TestCleanupOldRequestsUsesExpectedCutoff(t *testing.T) { + stub := &cleanupStoreStub{} + server := &Server{ + store: stub, + logger: log.New(io.Discard, "", 0), + } + + before := time.Now() + server.cleanupOldRequests() + after := time.Now() + + require.True(t, stub.called) + expectedLower := before.Add(-requestCleanupAfterApproval) + expectedUpper := after.Add(-requestCleanupAfterApproval) + require.True(t, stub.cutoff.After(expectedLower) || stub.cutoff.Equal(expectedLower)) + require.True(t, stub.cutoff.Before(expectedUpper) || stub.cutoff.Equal(expectedUpper)) +} diff --git a/internal/app/csrf.go b/internal/app/csrf.go new file mode 100644 index 0000000..9df44c9 --- /dev/null +++ b/internal/app/csrf.go @@ -0,0 +1,71 @@ +package app + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "net/http" +) + +type CSRFManager struct { + cookieName string + tokenBytes int +} + +func NewCSRFManager(cookieName string, tokenBytes int) *CSRFManager { + return &CSRFManager{cookieName: cookieName, tokenBytes: tokenBytes} +} + +func (m *CSRFManager) EnsureToken(w http.ResponseWriter, r *http.Request, secure bool) (string, error) { + if token := m.TokenFromRequest(r); token != "" { + return token, nil + } + token, err := m.GenerateToken() + if err != nil { + return "", err + } + m.SetCookie(w, token, secure) + return token, nil +} + +func (m *CSRFManager) GenerateToken() (string, error) { + buf := make([]byte, m.tokenBytes) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return base64.RawStdEncoding.EncodeToString(buf), nil +} + +func (m *CSRFManager) TokenFromRequest(r *http.Request) string { + cookie, err := r.Cookie(m.cookieName) + if err != nil { + return "" + } + return cookie.Value +} + +func (m *CSRFManager) SetCookie(w http.ResponseWriter, token string, secure bool) { + http.SetCookie(w, &http.Cookie{ + Name: m.cookieName, + Value: token, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Secure: secure, + }) +} + +func (m *CSRFManager) ValidateRequest(r *http.Request) bool { + cookieToken := m.TokenFromRequest(r) + if cookieToken == "" { + return false + } + if err := r.ParseForm(); err != nil { + return false + } + formToken := r.FormValue("csrf_token") + if formToken == "" { + return false + } + return subtle.ConstantTimeCompare([]byte(cookieToken), []byte(formToken)) == 1 +} diff --git a/internal/app/csrf_test.go b/internal/app/csrf_test.go new file mode 100644 index 0000000..de614bb --- /dev/null +++ b/internal/app/csrf_test.go @@ -0,0 +1,49 @@ +package app + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCSRFBlocksMissingToken(t *testing.T) { + server, _, _ := newTestServer(t) + + form := url.Values{} + form.Set("field", "value") + req := httptest.NewRequest(http.MethodPost, "/new/computer/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + server.withCSRF(handler).ServeHTTP(rec, req) + + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestCSRFAcceptsValidToken(t *testing.T) { + server, _, _ := newTestServer(t) + csrfToken, err := server.csrfManager.GenerateToken() + require.NoError(t, err) + + form := url.Values{} + form.Set("field", "value") + form.Set("csrf_token", csrfToken) + req := httptest.NewRequest(http.MethodPost, "/new/computer/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.AddCookie(&http.Cookie{Name: server.csrfManager.cookieName, Value: csrfToken}) + rec := httptest.NewRecorder() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + server.withCSRF(handler).ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) +} diff --git a/internal/app/handlers.go b/internal/app/handlers.go new file mode 100644 index 0000000..aba1516 --- /dev/null +++ b/internal/app/handlers.go @@ -0,0 +1,1427 @@ +package app + +import ( + "context" + "crypt-server/internal/store" + "encoding/json" + "errors" + "fmt" + "html" + "net" + "net/http" + "strconv" + "strings" + "time" + "unicode" +) + +func (s *Server) handleIndex(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + + data := TemplateData{ + Title: "Crypt", + User: s.currentUser(r), + } + computers, err := s.store.ListComputers() + if err != nil { + s.renderError(w, err) + return + } + outstanding, err := s.store.ListOutstandingRequests() + if err != nil { + s.renderError(w, err) + return + } + data.Computers = computers + data.OutstandingCount = len(outstanding) + + if err := s.renderTemplate(w, r, "index", data); err != nil { + s.renderError(w, err) + } +} + +func (s *Server) handleTableAjax(w http.ResponseWriter, r *http.Request) { + data := map[string]any{} + draw := 0 + if raw := r.URL.Query().Get("args"); raw != "" { + var payload map[string]any + if err := json.Unmarshal([]byte(raw), &payload); err == nil { + if value, ok := payload["draw"].(float64); ok { + draw = int(value) + } + } + } + + computers, err := s.store.ListComputers() + if err != nil { + s.renderError(w, err) + return + } + data["draw"] = draw + data["recordsTotal"] = len(computers) + data["recordsFiltered"] = len(computers) + + rows := make([][]string, 0, len(computers)) + for _, computer := range computers { + serial := html.EscapeString(computer.Serial) + computerName := html.EscapeString(computer.ComputerName) + username := html.EscapeString(computer.Username) + lastCheckin := "" + if !computer.LastCheckin.IsZero() { + lastCheckin = computer.LastCheckin.Format("2006-01-02 15:04") + } + + link := fmt.Sprintf("/info/%d/", computer.ID) + rows = append(rows, []string{ + fmt.Sprintf("%s", link, serial), + fmt.Sprintf("%s", link, computerName), + username, + lastCheckin, + fmt.Sprintf("Info", link), + }) + } + + data["data"] = rows + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(data); err != nil { + s.renderError(w, err) + } +} + +func (s *Server) handleNewComputer(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + data := TemplateData{Title: "New Computer", User: s.currentUser(r)} + if err := s.renderTemplate(w, r, "new_computer", data); err != nil { + s.renderError(w, err) + } + case http.MethodPost: + if err := r.ParseForm(); err != nil { + s.renderError(w, err) + return + } + serial := strings.TrimSpace(r.FormValue("serial")) + username := strings.TrimSpace(r.FormValue("username")) + computerName := strings.TrimSpace(r.FormValue("computername")) + if serial == "" || computerName == "" { + data := TemplateData{ + Title: "New Computer", + User: s.currentUser(r), + ErrorMessage: "Serial number and computer name are required.", + } + if err := s.renderTemplate(w, r, "new_computer", data); err != nil { + s.renderError(w, err) + } + return + } + computer, err := s.store.AddComputer(serial, username, computerName) + if err != nil { + s.renderError(w, err) + return + } + http.Redirect(w, r, fmt.Sprintf("/info/%d/", computer.ID), http.StatusSeeOther) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleNewSecret(w http.ResponseWriter, r *http.Request) { + computerID, err := idFromPath("/new/secret/", r.URL.Path) + if err != nil { + http.NotFound(w, r) + return + } + + computer, err := s.store.GetComputerByID(computerID) + if err != nil { + http.NotFound(w, r) + return + } + + switch r.Method { + case http.MethodGet: + data := TemplateData{Title: "New Secret", User: s.currentUser(r), Computer: computer} + if err := s.renderTemplate(w, r, "new_secret", data); err != nil { + s.renderError(w, err) + } + case http.MethodPost: + if err := r.ParseForm(); err != nil { + s.renderError(w, err) + return + } + secretType := strings.TrimSpace(r.FormValue("secret_type")) + secret := strings.TrimSpace(r.FormValue("secret")) + rotationRequired := r.FormValue("rotation_required") == "on" + + if secretType == "" || secret == "" { + data := TemplateData{ + Title: "New Secret", + User: s.currentUser(r), + Computer: computer, + ErrorMessage: "Secret type and value are required.", + } + if err := s.renderTemplate(w, r, "new_secret", data); err != nil { + s.renderError(w, err) + } + return + } + + _, newSecretEscrowed, err := s.store.AddSecret(computer.ID, secretType, secret, rotationRequired) + if err != nil { + s.renderError(w, err) + return + } + user := s.currentUser(r) + if newSecretEscrowed { + s.logger.Printf("secret escrowed manually: serial=%s type=%s by_user=%s", computer.Serial, secretType, user.Username) + } else { + s.logger.Printf("secret updated manually: serial=%s type=%s by_user=%s", computer.Serial, secretType, user.Username) + } + http.Redirect(w, r, fmt.Sprintf("/info/%d/", computer.ID), http.StatusSeeOther) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleComputerInfo(w http.ResponseWriter, r *http.Request) { + identifier := strings.TrimPrefix(r.URL.Path, "/info/") + identifier = strings.TrimSuffix(identifier, "/") + if identifier == "" { + http.NotFound(w, r) + return + } + + computer, err := s.lookupComputer(identifier) + if err != nil { + http.NotFound(w, r) + return + } + + secrets, err := s.store.ListSecretsByComputer(computer.ID) + if err != nil { + s.renderError(w, err) + return + } + + data := TemplateData{ + Title: "Computer Info", + User: s.currentUser(r), + Computer: computer, + Secrets: secrets, + } + if err := s.renderTemplate(w, r, "computer_info", data); err != nil { + s.renderError(w, err) + } +} + +func (s *Server) handleSecretInfo(w http.ResponseWriter, r *http.Request) { + secretID, err := idFromPath("/info/secret/", r.URL.Path) + if err != nil { + http.NotFound(w, r) + return + } + + secret, err := s.store.GetSecretByID(secretID) + if err != nil { + http.NotFound(w, r) + return + } + + computer, err := s.store.GetComputerByID(secret.ComputerID) + if err != nil { + http.NotFound(w, r) + return + } + + requests, err := s.store.ListRequestsBySecret(secret.ID) + if err != nil { + s.renderError(w, err) + return + } + canRequest := true + for _, request := range requests { + if request.RequestingUser == s.currentUser(r).Username && request.Approved == nil { + canRequest = false + } + } + approved := false + approvedRequestID := 0 + for _, request := range requests { + if request.RequestingUser == s.currentUser(r).Username && request.Approved != nil && *request.Approved { + approved = true + approvedRequestID = request.ID + } + } + + data := TemplateData{ + Title: "Secret Info", + User: s.currentUser(r), + Computer: computer, + Secret: secret, + CanRequest: canRequest, + RequestApproved: approved, + ApprovedRequestID: approvedRequestID, + RequestsForSecret: requests, + } + if err := s.renderTemplate(w, r, "secret_info", data); err != nil { + s.renderError(w, err) + } +} + +func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { + secretID, err := idFromPath("/request/", r.URL.Path) + if err != nil { + http.NotFound(w, r) + return + } + + secret, err := s.store.GetSecretByID(secretID) + if err != nil { + http.NotFound(w, r) + return + } + + computer, err := s.store.GetComputerByID(secret.ComputerID) + if err != nil { + http.NotFound(w, r) + return + } + + switch r.Method { + case http.MethodGet: + data := TemplateData{Title: "Request Secret", User: s.currentUser(r), Secret: secret, Computer: computer} + if err := s.renderTemplate(w, r, "request", data); err != nil { + s.renderError(w, err) + } + case http.MethodPost: + if err := r.ParseForm(); err != nil { + s.renderError(w, err) + return + } + reason := strings.TrimSpace(r.FormValue("reason_for_request")) + user := s.currentUser(r) + var approved *bool + var approver string + if user.CanApprove && s.settings.ApproveOwn { + approvedValue := true + approved = &approvedValue + approver = user.Username + } + _, err := s.store.AddRequest(secret.ID, user.Username, reason, approver, approved) + if err != nil { + s.renderError(w, err) + return + } + http.Redirect(w, r, fmt.Sprintf("/info/secret/%d/", secret.ID), http.StatusSeeOther) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleApprove(w http.ResponseWriter, r *http.Request) { + requestID, err := idFromPath("/approve/", r.URL.Path) + if err != nil { + http.NotFound(w, r) + return + } + + req, err := s.store.GetRequestByID(requestID) + if err != nil { + http.NotFound(w, r) + return + } + + secret, err := s.store.GetSecretByID(req.SecretID) + if err != nil { + http.NotFound(w, r) + return + } + + computer, err := s.store.GetComputerByID(secret.ComputerID) + if err != nil { + http.NotFound(w, r) + return + } + + switch r.Method { + case http.MethodGet: + if !s.canApproveRequest(r, req) { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + data := TemplateData{Title: "Approve Request", User: s.currentUser(r), Request: req, Secret: secret, Computer: computer} + if err := s.renderTemplate(w, r, "approve", data); err != nil { + s.renderError(w, err) + } + case http.MethodPost: + if !s.canApproveRequest(r, req) { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + if err := r.ParseForm(); err != nil { + s.renderError(w, err) + return + } + approvedValue := r.FormValue("approved") == "1" + reason := strings.TrimSpace(r.FormValue("reason_for_approval")) + if _, err := s.store.ApproveRequest(req.ID, approvedValue, reason, s.currentUser(r).Username); err != nil { + s.renderError(w, err) + return + } + http.Redirect(w, r, "/manage-requests/", http.StatusSeeOther) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleRetrieve(w http.ResponseWriter, r *http.Request) { + requestID, err := idFromPath("/retrieve/", r.URL.Path) + if err != nil { + http.NotFound(w, r) + return + } + + req, err := s.store.GetRequestByID(requestID) + if err != nil { + http.NotFound(w, r) + return + } + if req.Approved == nil || !*req.Approved { + http.Error(w, "request not approved", http.StatusForbidden) + return + } + user := s.currentUser(r) + if user.Username != req.RequestingUser && !user.CanApprove { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + + secret, err := s.store.GetSecretByID(req.SecretID) + if err != nil { + http.NotFound(w, r) + return + } + if s.settings.RotateViewedSecrets { + updated, err := s.store.SetSecretRotationRequired(secret.ID, true) + if err != nil { + s.renderError(w, err) + return + } + secret = updated + } + + computer, err := s.store.GetComputerByID(secret.ComputerID) + if err != nil { + http.NotFound(w, r) + return + } + + s.logger.Printf("secret retrieved: serial=%s type=%s by_user=%s requested_by=%s", computer.Serial, secret.SecretType, user.Username, req.RequestingUser) + + secretChars := make([]SecretChar, 0, len(secret.Secret)) + for _, char := range secret.Secret { + entry := SecretChar{Char: string(char), Class: "other"} + if unicode.IsLetter(char) { + entry.Class = "letter" + } else if unicode.IsDigit(char) { + entry.Class = "number" + } + secretChars = append(secretChars, entry) + } + + data := TemplateData{ + Title: "Retrieve Secret", + User: s.currentUser(r), + Request: req, + Secret: secret, + Computer: computer, + SecretChars: secretChars, + } + if err := s.renderTemplate(w, r, "retrieve", data); err != nil { + s.renderError(w, err) + } +} + +func (s *Server) handleManageRequests(w http.ResponseWriter, r *http.Request) { + if !s.currentUser(r).CanApprove { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + requests, err := s.store.ListOutstandingRequests() + if err != nil { + s.renderError(w, err) + return + } + views := make([]RequestView, 0, len(requests)) + for _, req := range requests { + secret, err := s.store.GetSecretByID(req.SecretID) + if err != nil { + continue + } + computer, err := s.store.GetComputerByID(secret.ComputerID) + if err != nil { + continue + } + views = append(views, RequestView{ + ID: req.ID, + Serial: computer.Serial, + ComputerName: computer.ComputerName, + RequestingUser: req.RequestingUser, + ReasonForRequest: req.ReasonForRequest, + DateRequested: req.DateRequested.Format(store.DateTimeFormat), + }) + } + data := TemplateData{Title: "Manage Requests", User: s.currentUser(r), ManageRequests: views} + if err := s.renderTemplate(w, r, "manage_requests", data); err != nil { + s.renderError(w, err) + } +} + +func (s *Server) handleAdminUsers(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/admin/users/" { + s.handleUserList(w, r) + return + } + if r.URL.Path == "/admin/users/new/" { + s.handleNewUser(w, r) + return + } + if strings.HasSuffix(r.URL.Path, "/edit/") { + s.handleUserEdit(w, r) + return + } + if strings.HasSuffix(r.URL.Path, "/password/") { + s.handleUserPassword(w, r) + return + } + if strings.HasSuffix(r.URL.Path, "/delete/") { + s.handleUserDelete(w, r) + return + } + http.NotFound(w, r) +} + +func (s *Server) handleUserList(w http.ResponseWriter, r *http.Request) { + if !s.currentUser(r).IsStaff { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + users, err := s.store.ListUsers() + if err != nil { + s.renderError(w, err) + return + } + data := TemplateData{Title: "Users", User: s.currentUser(r), Users: users} + if err := s.renderTemplate(w, r, "user_list", data); err != nil { + s.renderError(w, err) + } +} + +func (s *Server) handleNewUser(w http.ResponseWriter, r *http.Request) { + if !s.currentUser(r).IsStaff { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + switch r.Method { + case http.MethodGet: + data := TemplateData{ + Title: "New User", + User: s.currentUser(r), + NewUser: UserForm{ + LocalLoginEnabled: true, + AuthSource: "local", + }, + } + if err := s.renderTemplate(w, r, "user_new", data); err != nil { + s.renderError(w, err) + } + case http.MethodPost: + if err := r.ParseForm(); err != nil { + s.renderError(w, err) + return + } + username := strings.TrimSpace(r.FormValue("username")) + password := r.FormValue("password") + isStaff := r.FormValue("is_staff") == "on" + canApprove := r.FormValue("can_approve") == "on" + localLoginEnabled := r.FormValue("local_login_enabled") == "on" + mustReset := r.FormValue("must_reset_password") == "on" + authSource := strings.TrimSpace(r.FormValue("auth_source")) + if authSource == "" { + authSource = "local" + } + if username == "" || (localLoginEnabled && password == "") { + data := TemplateData{ + Title: "New User", + User: s.currentUser(r), + ErrorMessage: "Username and password are required when local login is enabled.", + NewUser: UserForm{ + Username: username, + IsStaff: isStaff, + CanApprove: canApprove, + LocalLoginEnabled: localLoginEnabled, + MustResetPassword: mustReset, + AuthSource: authSource, + }, + } + if err := s.renderTemplate(w, r, "user_new", data); err != nil { + s.renderError(w, err) + } + return + } + if _, err := s.store.GetUserByUsername(username); err == nil { + data := TemplateData{ + Title: "New User", + User: s.currentUser(r), + ErrorMessage: "Username already exists.", + NewUser: UserForm{ + Username: username, + IsStaff: isStaff, + CanApprove: canApprove, + LocalLoginEnabled: localLoginEnabled, + MustResetPassword: mustReset, + AuthSource: authSource, + }, + } + if err := s.renderTemplate(w, r, "user_new", data); err != nil { + s.renderError(w, err) + } + return + } else if err != store.ErrNotFound { + s.renderError(w, err) + return + } + var passwordHash string + if localLoginEnabled { + var err error + passwordHash, err = hashPassword(password) + if err != nil { + s.renderError(w, err) + return + } + } + if _, err := s.store.AddUser(username, passwordHash, isStaff, canApprove, localLoginEnabled, mustReset, authSource); err != nil { + s.renderError(w, err) + return + } + s.logger.Printf("user created: username=%s is_staff=%t can_approve=%t by_user=%s", username, isStaff, canApprove, s.currentUser(r).Username) + if _, err := s.store.AddAuditEvent(s.currentUser(r).Username, username, "user_created", "", clientIP(r)); err != nil { + s.renderError(w, err) + return + } + http.Redirect(w, r, "/admin/users/", http.StatusSeeOther) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleUserEdit(w http.ResponseWriter, r *http.Request) { + if !s.currentUser(r).IsStaff { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + userID, err := idFromPath("/admin/users/", strings.TrimSuffix(r.URL.Path, "/edit/")+"/") + if err != nil { + http.NotFound(w, r) + return + } + user, err := s.store.GetUserByID(userID) + if err != nil { + http.NotFound(w, r) + return + } + switch r.Method { + case http.MethodGet: + data := TemplateData{ + Title: "Edit User", + User: s.currentUser(r), + AdminUser: user, + NewUser: UserForm{ + Username: user.Username, + IsStaff: user.IsStaff, + CanApprove: user.CanApprove, + LocalLoginEnabled: user.LocalLoginEnabled, + MustResetPassword: user.MustResetPassword, + AuthSource: user.AuthSource, + }, + } + if err := s.renderTemplate(w, r, "user_edit", data); err != nil { + s.renderError(w, err) + } + case http.MethodPost: + if err := r.ParseForm(); err != nil { + s.renderError(w, err) + return + } + username := strings.TrimSpace(r.FormValue("username")) + isStaff := r.FormValue("is_staff") == "on" + canApprove := r.FormValue("can_approve") == "on" + localLoginEnabled := r.FormValue("local_login_enabled") == "on" + mustReset := r.FormValue("must_reset_password") == "on" + authSource := strings.TrimSpace(r.FormValue("auth_source")) + if authSource == "" { + authSource = "local" + } + if username == "" { + data := TemplateData{ + Title: "Edit User", + User: s.currentUser(r), + AdminUser: user, + ErrorMessage: "Username is required.", + NewUser: UserForm{ + Username: username, + IsStaff: isStaff, + CanApprove: canApprove, + LocalLoginEnabled: localLoginEnabled, + MustResetPassword: mustReset, + AuthSource: authSource, + }, + } + if err := s.renderTemplate(w, r, "user_edit", data); err != nil { + s.renderError(w, err) + } + return + } + if existing, err := s.store.GetUserByUsername(username); err == nil && existing.ID != user.ID { + data := TemplateData{ + Title: "Edit User", + User: s.currentUser(r), + AdminUser: user, + ErrorMessage: "Username already exists.", + NewUser: UserForm{ + Username: username, + IsStaff: isStaff, + CanApprove: canApprove, + LocalLoginEnabled: localLoginEnabled, + MustResetPassword: mustReset, + AuthSource: authSource, + }, + } + if err := s.renderTemplate(w, r, "user_edit", data); err != nil { + s.renderError(w, err) + } + return + } else if err != nil && err != store.ErrNotFound { + s.renderError(w, err) + return + } + updated, err := s.store.UpdateUser(user.ID, username, isStaff, canApprove, localLoginEnabled, mustReset, authSource) + if err != nil { + s.renderError(w, err) + return + } + s.logger.Printf("user edited: username=%s is_staff=%t can_approve=%t by_user=%s", username, isStaff, canApprove, s.currentUser(r).Username) + if reason := buildUserUpdateReason(user, updated); reason != "" { + if _, err := s.store.AddAuditEvent(s.currentUser(r).Username, updated.Username, "user_updated", reason, clientIP(r)); err != nil { + s.renderError(w, err) + return + } + } + if user.MustResetPassword != mustReset { + action := "force_reset_disabled" + if mustReset { + action = "force_reset_enabled" + } + if _, err := s.store.AddAuditEvent(s.currentUser(r).Username, updated.Username, action, "", clientIP(r)); err != nil { + s.renderError(w, err) + return + } + } + http.Redirect(w, r, fmt.Sprintf("/admin/users/%d/edit/", updated.ID), http.StatusSeeOther) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleUserPassword(w http.ResponseWriter, r *http.Request) { + if !s.currentUser(r).IsStaff { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + userID, err := idFromPath("/admin/users/", strings.TrimSuffix(r.URL.Path, "/password/")+"/") + if err != nil { + http.NotFound(w, r) + return + } + user, err := s.store.GetUserByID(userID) + if err != nil { + http.NotFound(w, r) + return + } + switch r.Method { + case http.MethodGet: + data := TemplateData{Title: "Reset Password", User: s.currentUser(r), AdminUser: user} + if err := s.renderTemplate(w, r, "user_password", data); err != nil { + s.renderError(w, err) + } + case http.MethodPost: + if err := r.ParseForm(); err != nil { + s.renderError(w, err) + return + } + password := r.FormValue("password") + if password == "" { + data := TemplateData{ + Title: "Reset Password", + User: s.currentUser(r), + AdminUser: user, + ErrorMessage: "Password is required.", + } + if err := s.renderTemplate(w, r, "user_password", data); err != nil { + s.renderError(w, err) + } + return + } + passwordHash, err := hashPassword(password) + if err != nil { + s.renderError(w, err) + return + } + if _, err := s.store.UpdateUserPassword(user.ID, passwordHash, false); err != nil { + s.renderError(w, err) + return + } + if _, err := s.store.AddAuditEvent(s.currentUser(r).Username, user.Username, "password_reset", "", clientIP(r)); err != nil { + s.renderError(w, err) + return + } + http.Redirect(w, r, fmt.Sprintf("/admin/users/%d/edit/", user.ID), http.StatusSeeOther) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleAuditLog(w http.ResponseWriter, r *http.Request) { + if !s.currentUser(r).IsStaff { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + query := strings.TrimSpace(r.URL.Query().Get("q")) + page := parsePage(r.URL.Query().Get("page")) + const pageSize = 50 + offset := (page - 1) * pageSize + + var ( + events []*store.AuditEvent + total int + err error + ) + if query == "" { + total, err = s.store.CountAuditEvents() + if err != nil { + s.renderError(w, err) + return + } + events, err = s.store.ListAuditEventsPaged(pageSize, offset) + } else { + total, err = s.store.CountSearchAuditEvents(query) + if err != nil { + s.renderError(w, err) + return + } + events, err = s.store.SearchAuditEventsPaged(query, pageSize, offset) + } + if err != nil { + s.renderError(w, err) + return + } + + totalPages := 1 + if total > 0 { + totalPages = (total + pageSize - 1) / pageSize + if page > totalPages { + page = totalPages + } + } + + data := TemplateData{ + Title: "Audit Log", + User: s.currentUser(r), + AuditEvents: events, + AuditSearch: query, + AuditPage: page, + AuditPageSize: pageSize, + AuditTotal: total, + AuditTotalPages: totalPages, + } + if err := s.renderTemplate(w, r, "audit_log", data); err != nil { + s.renderError(w, err) + } +} + +func (s *Server) handleUserDelete(w http.ResponseWriter, r *http.Request) { + if !s.currentUser(r).IsStaff { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + userID, err := idFromPath("/admin/users/", strings.TrimSuffix(r.URL.Path, "/delete/")+"/") + if err != nil { + http.NotFound(w, r) + return + } + user, err := s.store.GetUserByID(userID) + if err != nil { + http.NotFound(w, r) + return + } + switch r.Method { + case http.MethodGet: + data := TemplateData{Title: "Delete User", User: s.currentUser(r), AdminUser: user} + if err := s.renderTemplate(w, r, "user_delete", data); err != nil { + s.renderError(w, err) + } + case http.MethodPost: + if s.currentUser(r).ID == user.ID { + data := TemplateData{ + Title: "Delete User", + User: s.currentUser(r), + AdminUser: user, + ErrorMessage: "You cannot delete your own account.", + } + if err := s.renderTemplate(w, r, "user_delete", data); err != nil { + s.renderError(w, err) + } + return + } + if err := s.store.DeleteUser(user.ID); err != nil { + s.renderError(w, err) + return + } + s.logger.Printf("user deleted: username=%s by_user=%s", user.Username, s.currentUser(r).Username) + if _, err := s.store.AddAuditEvent(s.currentUser(r).Username, user.Username, "user_deleted", "", clientIP(r)); err != nil { + s.renderError(w, err) + return + } + http.Redirect(w, r, "/admin/users/", http.StatusSeeOther) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (s *Server) handlePasswordChange(w http.ResponseWriter, r *http.Request) { + user := s.currentUser(r) + if !user.LocalLoginEnabled { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + switch r.Method { + case http.MethodGet: + data := TemplateData{ + Title: "Change Password", + User: user, + PasswordChangeRequiresCurrent: true, + } + if err := s.renderTemplate(w, r, "password_change", data); err != nil { + s.renderError(w, err) + } + case http.MethodPost: + if err := r.ParseForm(); err != nil { + s.renderError(w, err) + return + } + current := r.FormValue("current_password") + next := r.FormValue("new_password") + dbUser, err := s.store.GetUserByUsername(user.Username) + if err != nil || dbUser.PasswordHash == "" || !verifyPassword(current, dbUser.PasswordHash) { + data := TemplateData{ + Title: "Change Password", + User: user, + ErrorMessage: "Current password is incorrect.", + PasswordChangeRequiresCurrent: true, + } + if err := s.renderTemplate(w, r, "password_change", data); err != nil { + s.renderError(w, err) + } + return + } + hash, err := hashPassword(next) + if err != nil { + s.renderError(w, err) + return + } + if _, err := s.store.UpdateUserPassword(user.ID, hash, false); err != nil { + s.renderError(w, err) + return + } + http.Redirect(w, r, "/", http.StatusSeeOther) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (s *Server) handlePasswordReset(w http.ResponseWriter, r *http.Request) { + user := s.currentUser(r) + if !user.LocalLoginEnabled { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + if !user.MustResetPassword { + http.Redirect(w, r, "/password/change/", http.StatusSeeOther) + return + } + switch r.Method { + case http.MethodGet: + data := TemplateData{ + Title: "Reset Password", + User: user, + PasswordChangeRequiresCurrent: false, + } + if err := s.renderTemplate(w, r, "password_change", data); err != nil { + s.renderError(w, err) + } + case http.MethodPost: + if err := r.ParseForm(); err != nil { + s.renderError(w, err) + return + } + next := r.FormValue("new_password") + hash, err := hashPassword(next) + if err != nil { + s.renderError(w, err) + return + } + if _, err := s.store.UpdateUserPassword(user.ID, hash, false); err != nil { + s.renderError(w, err) + return + } + http.Redirect(w, r, "/", http.StatusSeeOther) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleSAMLLogin(w http.ResponseWriter, r *http.Request) { + if s.samlSP == nil { + http.Error(w, "SAML login not configured", http.StatusNotImplemented) + return + } + s.samlSP.HandleStartAuthFlow(w, r) +} + +func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + data := TemplateData{Title: "Login", User: s.currentUser(r)} + if err := s.renderTemplate(w, r, "login", data); err != nil { + s.renderError(w, err) + } + case http.MethodPost: + if err := r.ParseForm(); err != nil { + s.renderError(w, err) + return + } + username := strings.TrimSpace(r.FormValue("username")) + password := r.FormValue("password") + next := r.FormValue("next") + + user, err := s.store.GetUserByUsername(username) + if err != nil || !user.LocalLoginEnabled || user.PasswordHash == "" { + s.renderLoginError(w, r, "Invalid username or password.") + return + } + if !verifyPassword(password, user.PasswordHash) { + s.renderLoginError(w, r, "Invalid username or password.") + return + } + token, err := s.sessionManager.Create(user.Username) + if err != nil { + s.renderError(w, err) + return + } + s.sessionManager.SetCookie(w, token, s.settings.CookieSecure) + if user.MustResetPassword && user.LocalLoginEnabled { + next = "/password/reset/" + } + if next == "" || !strings.HasPrefix(next, "/") { + next = "/" + } + http.Redirect(w, r, next, http.StatusSeeOther) + default: + w.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) { + s.sessionManager.ClearCookie(w, s.settings.CookieSecure) + http.Redirect(w, r, "/login/", http.StatusSeeOther) +} + +func (s *Server) handleCheckin(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + serial := r.FormValue("serial") + if serial == "" { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + recoveryPass := r.FormValue("recovery_password") + if recoveryPass == "" { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + userName := r.FormValue("username") + if userName == "" { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + macName := r.FormValue("macname") + if macName == "" { + macName = serial + } + secretType := r.FormValue("secret_type") + if secretType == "" { + secretType = "recovery_key" + } + secretType = strings.TrimSpace(secretType) + if secretType == "" { + secretType = "recovery_key" + } + + now := time.Now() + computer, err := s.store.UpsertComputer(serial, userName, macName, now) + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + + secret, newSecretEscrowed, err := s.store.AddSecret(computer.ID, secretType, recoveryPass, false) + if err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + + if newSecretEscrowed { + s.logger.Printf("secret escrowed: serial=%s type=%s username=%s macname=%s", serial, secretType, userName, macName) + } else { + s.logger.Printf("secret updated: serial=%s type=%s username=%s macname=%s", serial, secretType, userName, macName) + } + + payload := map[string]any{ + "serial": computer.Serial, + "username": computer.Username, + "rotation_required": secret.RotationRequired, + "new_secret_escrowed": newSecretEscrowed, + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(payload); err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } +} + +func (s *Server) handleVerify(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/verify/") + path = strings.TrimSuffix(path, "/") + parts := strings.Split(path, "/") + if len(parts) != 2 { + http.NotFound(w, r) + return + } + serial := parts[0] + secretType := parts[1] + if serial == "" || secretType == "" { + http.NotFound(w, r) + return + } + + computer, err := s.store.GetComputerBySerial(serial) + if err != nil { + http.NotFound(w, r) + return + } + secret, err := s.store.GetLatestSecretByComputerAndType(computer.ID, secretType) + payload := map[string]any{} + if err == nil { + payload["escrowed"] = true + payload["date_escrowed"] = secret.DateEscrowed.Format(time.RFC3339) + } else if err == store.ErrNotFound { + payload["escrowed"] = false + } else { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(payload); err != nil { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } +} + +func (s *Server) currentUser(r *http.Request) User { + if user := userFromContext(r.Context()); user != nil { + return *user + } + return User{} +} + +func (s *Server) renderError(w http.ResponseWriter, err error) { + s.logger.Printf("handler error: %v", err) + http.Error(w, "Something went wrong", http.StatusInternalServerError) +} + +func (s *Server) renderTemplate(w http.ResponseWriter, r *http.Request, name string, data TemplateData) error { + data.CSRFToken = s.csrfToken(w, r) + data.SAMLAvailable = s.samlSP != nil + data.SAMLLoginURL = s.samlLoginURL() + data.Version = Version + return s.renderer.Render(w, name, data) +} + +func buildUserUpdateReason(before, after *store.User) string { + changes := make([]string, 0, 5) + if before.Username != after.Username { + changes = append(changes, "username") + } + if before.IsStaff != after.IsStaff { + changes = append(changes, "is_staff") + } + if before.CanApprove != after.CanApprove { + changes = append(changes, "can_approve") + } + if before.LocalLoginEnabled != after.LocalLoginEnabled { + changes = append(changes, "local_login_enabled") + } + if before.AuthSource != after.AuthSource { + changes = append(changes, "auth_source") + } + return strings.Join(changes, ",") +} + +func clientIP(r *http.Request) string { + if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" { + parts := strings.Split(forwarded, ",") + if len(parts) > 0 { + return strings.TrimSpace(parts[0]) + } + } + if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" { + return realIP + } + host, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr)) + if err == nil { + return host + } + return strings.TrimSpace(r.RemoteAddr) +} + +func parsePage(raw string) int { + value := strings.TrimSpace(raw) + if value == "" { + return 1 + } + page, err := strconv.Atoi(value) + if err != nil || page < 1 { + return 1 + } + return page +} + +func (s *Server) csrfToken(w http.ResponseWriter, r *http.Request) string { + if s.csrfManager == nil { + return "" + } + token, err := s.csrfManager.EnsureToken(w, r, s.settings.CookieSecure) + if err != nil { + return "" + } + return token +} + +func (s *Server) samlLoginURL() string { + if s.samlConfig == nil { + return "/saml/login/" + } + if strings.HasPrefix(s.samlConfig.MetadataURLPath, "/saml2/") { + return "/saml2/login/" + } + return "/saml/login/" +} + +func (s *Server) renderLoginError(w http.ResponseWriter, r *http.Request, message string) { + data := TemplateData{ + Title: "Login", + User: s.currentUser(r), + ErrorMessage: message, + } + if err := s.renderTemplate(w, r, "login", data); err != nil { + s.renderError(w, err) + } +} + +func idFromPath(prefix, path string) (int, error) { + if !strings.HasPrefix(path, prefix) { + return 0, errors.New("invalid path") + } + trimmed := strings.TrimPrefix(path, prefix) + trimmed = strings.TrimSuffix(trimmed, "/") + if trimmed == "" { + return 0, errors.New("missing id") + } + return strconv.Atoi(trimmed) +} + +func (s *Server) lookupComputer(identifier string) (*store.Computer, error) { + if id, err := strconv.Atoi(identifier); err == nil { + return s.store.GetComputerByID(id) + } + return s.store.GetComputerBySerial(identifier) +} + +func (s *Server) loadUserFromRequest(r *http.Request) *User { + if s.sessionManager == nil { + return s.loadUserFromSAML(r) + } + cookie, err := r.Cookie(s.sessionManager.CookieName()) + if err != nil { + return s.loadUserFromSAML(r) + } + username, ok := s.sessionManager.Validate(cookie.Value) + if !ok { + return s.loadUserFromSAML(r) + } + dbUser, err := s.store.GetUserByUsername(username) + if err != nil { + return s.loadUserFromSAML(r) + } + user := mapStoreUser(dbUser) + user.IsAuthenticated = true + if s.settings.AllApprove { + user.CanApprove = true + } + return &user +} + +func (s *Server) loadUserFromSAML(r *http.Request) *User { + if s.samlSP == nil || s.samlConfig == nil { + return nil + } + session, err := s.samlSP.Session.GetSession(r) + if err != nil { + return nil + } + username := usernameFromSAML(session, s.samlConfig) + if username == "" { + return nil + } + attributes := attributesFromSession(session) + groups := groupMembership(attributes, s.samlConfig.GroupsAttribute) + isStaff, canApprove := resolveSAMLPermissions(groups, s.samlConfig) + + dbUser, err := s.store.GetUserByUsername(username) + if err != nil { + if err != store.ErrNotFound || !s.samlConfig.CreateUnknownUser { + return nil + } + newUser, err := s.store.AddUser( + username, + "", + isStaff, + canApprove, + s.samlConfig.DefaultLocalLogin, + s.samlConfig.DefaultMustReset, + s.samlConfig.DefaultAuthSource, + ) + if err != nil { + return nil + } + s.logger.Printf("user created via SAML: username=%s is_staff=%t can_approve=%t", username, isStaff, canApprove) + _, _ = s.store.AddAuditEvent("saml", username, "user_created", "created via SAML login", clientIP(r)) + user := mapStoreUser(newUser) + user.IsAuthenticated = true + return &user + } + + updatedUser := dbUser + shouldUpdate := false + if s.shouldUpdateStaff() { + if dbUser.IsStaff != isStaff { + updatedUser.IsStaff = isStaff + shouldUpdate = true + } + } + if s.shouldUpdateApprover() { + if dbUser.CanApprove != canApprove { + updatedUser.CanApprove = canApprove + shouldUpdate = true + } + } + if s.samlConfig.DefaultAuthSource != "" && dbUser.AuthSource != s.samlConfig.DefaultAuthSource { + updatedUser.AuthSource = s.samlConfig.DefaultAuthSource + shouldUpdate = true + } + if shouldUpdate { + updatedUser, err = s.store.UpdateUser( + dbUser.ID, + dbUser.Username, + updatedUser.IsStaff, + updatedUser.CanApprove, + dbUser.LocalLoginEnabled, + dbUser.MustResetPassword, + updatedUser.AuthSource, + ) + if err != nil { + return nil + } + s.logger.Printf("user updated via SAML: username=%s is_staff=%t can_approve=%t", username, updatedUser.IsStaff, updatedUser.CanApprove) + _, _ = s.store.AddAuditEvent("saml", username, "user_updated", "permissions updated via SAML login", clientIP(r)) + } + user := mapStoreUser(updatedUser) + user.IsAuthenticated = true + return &user +} + +func (s *Server) shouldUpdateStaff() bool { + if s.samlConfig == nil { + return false + } + return len(s.samlConfig.StaffGroups) > 0 || len(s.samlConfig.SuperuserGroups) > 0 +} + +func (s *Server) shouldUpdateApprover() bool { + if s.samlConfig == nil { + return false + } + return len(s.samlConfig.CanApproveGroups) > 0 || len(s.samlConfig.SuperuserGroups) > 0 +} + +func mapStoreUser(user *store.User) User { + return User{ + ID: user.ID, + Username: user.Username, + IsStaff: user.IsStaff, + CanApprove: user.CanApprove, + LocalLoginEnabled: user.LocalLoginEnabled, + MustResetPassword: user.MustResetPassword, + AuthSource: user.AuthSource, + } +} + +func userFromContext(ctx context.Context) *User { + if value := ctx.Value(userContextKey); value != nil { + if user, ok := value.(*User); ok { + return user + } + } + return nil +} + +func (s *Server) canApproveRequest(r *http.Request, req *store.Request) bool { + user := s.currentUser(r) + if !user.CanApprove { + return false + } + if !s.settings.ApproveOwn && user.Username == req.RequestingUser { + return false + } + return true +} diff --git a/internal/app/handlers_test.go b/internal/app/handlers_test.go new file mode 100644 index 0000000..84b8fb9 --- /dev/null +++ b/internal/app/handlers_test.go @@ -0,0 +1,1461 @@ +package app + +import ( + "encoding/base64" + "encoding/json" + "io" + "log" + "net/http" + "net/http/httptest" + "net/url" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "crypt-server/internal/crypto" + "crypt-server/internal/migrate" + "crypt-server/internal/store" + "github.com/stretchr/testify/require" +) + +func newTestServer(t *testing.T) (*Server, store.Store, *SessionManager) { + t.Helper() + codec := testCodec(t) + dataStore := newTestSQLiteStore(t, codec) + server, sessionManager := newTestServerWithStore(t, dataStore) + passwordHash := hashPasswordForTest(t, "password") + _, err := dataStore.AddUser("admin", passwordHash, true, true, true, false, "local") + require.NoError(t, err) + return server, dataStore, sessionManager +} + +func newTestSQLiteStore(t *testing.T, codec *crypto.AesGcmCodec) *store.SQLiteStore { + t.Helper() + path := filepath.Join(t.TempDir(), "crypt.db") + sqliteStore, err := store.NewSQLiteStore(path, codec) + require.NoError(t, err) + migrationFS, err := migrate.SubMigrationsFS(migrate.EmbeddedFS, "sqlite") + require.NoError(t, err) + require.NoError(t, migrate.Apply(sqliteStore.DB(), "sqlite", migrationFS)) + return sqliteStore +} + +func newTestServerWithStore(t *testing.T, dataStore store.Store) (*Server, *SessionManager) { + t.Helper() + root := filepath.Join("..", "..") + layout := filepath.Join(root, "web", "templates", "layouts", "base.html") + pages := filepath.Join(root, "web", "templates", "pages") + renderer := NewRenderer(layout, pages) + logger := log.New(io.Discard, "", 0) + sessionManager, err := NewSessionManager([]byte("test-session-key-32-bytes-long!!"), "crypt_session", time.Hour) + require.NoError(t, err) + settings := Settings{ + ApproveOwn: true, + AllApprove: false, + SessionTTL: time.Hour, + CookieSecure: false, + RequestCleanupInterval: 0, + RotateViewedSecrets: false, + } + csrfManager := NewCSRFManager("crypt_csrf", 32) + server := NewServer(dataStore, renderer, logger, sessionManager, csrfManager, nil, nil, settings) + return server, sessionManager +} + +type rotationTrackingStore struct { + *store.SQLiteStore + called bool +} + +func (s *rotationTrackingStore) SetSecretRotationRequired(secretID int, rotationRequired bool) (*store.Secret, error) { + s.called = true + return s.SQLiteStore.SetSecretRotationRequired(secretID, rotationRequired) +} + +type auditPaginationStore struct { + lastLimit int + lastOffset int +} + +func (s *auditPaginationStore) ListAuditEventsPaged(limit, offset int) ([]*store.AuditEvent, error) { + s.lastLimit = limit + s.lastOffset = offset + return []*store.AuditEvent{}, nil +} + +func (s *auditPaginationStore) SearchAuditEventsPaged(query string, limit, offset int) ([]*store.AuditEvent, error) { + s.lastLimit = limit + s.lastOffset = offset + return []*store.AuditEvent{}, nil +} + +func (s *auditPaginationStore) CountAuditEvents() (int, error) { + return 1, nil +} + +func (s *auditPaginationStore) CountSearchAuditEvents(query string) (int, error) { + return 1, nil +} + +func (s *auditPaginationStore) ListAuditEvents() ([]*store.AuditEvent, error) { + return []*store.AuditEvent{}, nil +} + +func (s *auditPaginationStore) SearchAuditEvents(query string) ([]*store.AuditEvent, error) { + return []*store.AuditEvent{}, nil +} + +func (s *auditPaginationStore) AddAuditEvent(actor, targetUser, action, reason, ipAddress string) (*store.AuditEvent, error) { + return &store.AuditEvent{}, nil +} + +func (s *auditPaginationStore) AddComputer(serial, username, computerName string) (*store.Computer, error) { + return nil, store.ErrNotFound +} + +func (s *auditPaginationStore) UpsertComputer(serial, username, computerName string, lastCheckin time.Time) (*store.Computer, error) { + return nil, store.ErrNotFound +} + +func (s *auditPaginationStore) ListComputers() ([]*store.Computer, error) { + return []*store.Computer{}, nil +} + +func (s *auditPaginationStore) GetComputerByID(id int) (*store.Computer, error) { + return nil, store.ErrNotFound +} + +func (s *auditPaginationStore) GetComputerBySerial(serial string) (*store.Computer, error) { + return nil, store.ErrNotFound +} + +func (s *auditPaginationStore) AddSecret(computerID int, secretType, secret string, rotationRequired bool) (*store.Secret, bool, error) { + return nil, false, store.ErrNotFound +} + +func (s *auditPaginationStore) ListSecretsByComputer(computerID int) ([]*store.Secret, error) { + return []*store.Secret{}, nil +} + +func (s *auditPaginationStore) GetSecretByID(id int) (*store.Secret, error) { + return nil, store.ErrNotFound +} + +func (s *auditPaginationStore) GetLatestSecretByComputerAndType(computerID int, secretType string) (*store.Secret, error) { + return nil, store.ErrNotFound +} + +func (s *auditPaginationStore) AddRequest(secretID int, requestingUser, reason string, approvedBy string, approved *bool) (*store.Request, error) { + return nil, store.ErrNotFound +} + +func (s *auditPaginationStore) ListRequestsBySecret(secretID int) ([]*store.Request, error) { + return []*store.Request{}, nil +} + +func (s *auditPaginationStore) ListOutstandingRequests() ([]*store.Request, error) { + return []*store.Request{}, nil +} + +func (s *auditPaginationStore) GetRequestByID(id int) (*store.Request, error) { + return nil, store.ErrNotFound +} + +func (s *auditPaginationStore) ApproveRequest(requestID int, approved bool, reason, approver string) (*store.Request, error) { + return nil, store.ErrNotFound +} + +func (s *auditPaginationStore) AddUser(username, passwordHash string, isStaff, canApprove, localLoginEnabled, mustResetPassword bool, authSource string) (*store.User, error) { + return nil, nil +} + +func (s *auditPaginationStore) GetUserByUsername(username string) (*store.User, error) { + return &store.User{ID: 1, Username: username, IsStaff: true, LocalLoginEnabled: true}, nil +} + +func (s *auditPaginationStore) ListUsers() ([]*store.User, error) { + return []*store.User{}, nil +} + +func (s *auditPaginationStore) GetUserByID(id int) (*store.User, error) { + return &store.User{ID: id, Username: "user"}, nil +} + +func (s *auditPaginationStore) UpdateUser(id int, username string, isStaff, canApprove, localLoginEnabled, mustResetPassword bool, authSource string) (*store.User, error) { + return &store.User{ID: id, Username: username}, nil +} + +func (s *auditPaginationStore) UpdateUserPassword(id int, passwordHash string, mustResetPassword bool) (*store.User, error) { + return &store.User{ID: id, Username: "user"}, nil +} + +func (s *auditPaginationStore) DeleteUser(id int) error { + return nil +} + +func (s *auditPaginationStore) CleanupRequests(approvedBefore time.Time) (int, error) { + return 0, nil +} + +func (s *auditPaginationStore) SetSecretRotationRequired(secretID int, rotationRequired bool) (*store.Secret, error) { + return nil, store.ErrNotFound +} + +func (s *auditPaginationStore) IsEmpty() (bool, error) { + return true, nil +} + +func (s *auditPaginationStore) ImportComputer(id int, serial, username, computerName string, lastCheckin time.Time) error { + return nil +} + +func (s *auditPaginationStore) ImportSecret(id, computerID int, secretType, encryptedSecret string, dateEscrowed time.Time, rotationRequired bool) error { + return nil +} + +func (s *auditPaginationStore) ImportRequest(id, secretID int, requestingUser string, approved *bool, authUser, reasonForRequest, reasonForApproval string, dateRequested time.Time, dateApproved *time.Time, current bool) error { + return nil +} + +func (s *auditPaginationStore) ImportUser(id int, username, passwordHash string, isStaff, canApprove, localLoginEnabled, mustResetPassword bool, authSource string) error { + return nil +} + +func TestHandleIndex(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + _, err := memStore.AddComputer("SERIAL1", "user", "Mac") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/", nil, "admin") + serveProtected(server, rec, req, server.handleIndex) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "Serial Number") +} + +func TestHandleTableAjax(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + _, err := memStore.AddComputer("SERIAL2", "user", "iMac") + require.NoError(t, err) + + payload := map[string]any{"draw": 1} + payloadBytes, _ := json.Marshal(payload) + query := url.Values{} + query.Set("args", string(payloadBytes)) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/ajax/?"+query.Encode(), nil, "admin") + serveProtected(server, rec, req, server.handleTableAjax) + + require.Equal(t, http.StatusOK, rec.Code) + + var data map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &data)) + require.Equal(t, float64(1), data["recordsTotal"]) +} + +func TestHandleNewComputerFlow(t *testing.T) { + server, _, sessionManager := newTestServer(t) + form := url.Values{} + form.Set("serial", "SERIAL3") + form.Set("username", "user3") + form.Set("computername", "MacBook Air") + + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/new/computer/", form, "admin") + serveProtected(server, rec, req, server.handleNewComputer) + + require.Equal(t, http.StatusSeeOther, rec.Code) + require.Contains(t, rec.Header().Get("Location"), "/info/") +} + +func TestRequestApproveRetrieveFlow(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + computer, err := memStore.AddComputer("SERIAL4", "user4", "MacBook Pro") + require.NoError(t, err) + secret, _, err := memStore.AddSecret(computer.ID, "recovery_key", "secret-value", false) + require.NoError(t, err) + + form := url.Values{} + form.Set("reason_for_request", "Need access") + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/request/"+intToString(secret.ID)+"/", form, "admin") + serveProtected(server, rec, req, server.handleRequest) + + require.Equal(t, http.StatusSeeOther, rec.Code) + + requests, err := memStore.ListRequestsBySecret(secret.ID) + require.NoError(t, err) + require.Len(t, requests, 1) + + infoRec := httptest.NewRecorder() + infoReq := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/info/secret/"+intToString(secret.ID)+"/", nil, "admin") + serveProtected(server, infoRec, infoReq, server.handleSecretInfo) + require.Contains(t, infoRec.Body.String(), "Retrieve Key") + + retrieveRec := httptest.NewRecorder() + retrieveReq := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/retrieve/"+intToString(requests[0].ID)+"/", nil, "admin") + serveProtected(server, retrieveRec, retrieveReq, server.handleRetrieve) + require.Equal(t, http.StatusOK, retrieveRec.Code) + require.Contains(t, retrieveRec.Body.String(), "class=\"letter\">s") +} + +func TestRetrieveMarksRotationRequired(t *testing.T) { + codec := testCodec(t) + sqliteStore := newTestSQLiteStore(t, codec) + dataStore := &rotationTrackingStore{SQLiteStore: sqliteStore} + server, sessionManager := newTestServerWithStore(t, dataStore) + passwordHash := hashPasswordForTest(t, "password") + _, err := dataStore.AddUser("admin", passwordHash, true, true, true, false, "local") + require.NoError(t, err) + server.settings.RotateViewedSecrets = true + require.True(t, server.settings.RotateViewedSecrets) + computer, err := dataStore.AddComputer("SERIALROTATE", "user", "MacBook Pro") + require.NoError(t, err) + secret, _, err := dataStore.AddSecret(computer.ID, "recovery_key", "secret-value", false) + require.NoError(t, err) + initial, err := dataStore.GetSecretByID(secret.ID) + require.NoError(t, err) + require.False(t, initial.RotationRequired) + approved := true + req, err := dataStore.AddRequest(secret.ID, "admin", "Need access", "approver", &approved) + require.NoError(t, err) + + rec := httptest.NewRecorder() + retrieveReq := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/retrieve/"+intToString(req.ID)+"/", nil, "admin") + serveProtected(server, rec, retrieveReq, server.handleRetrieve) + require.Equal(t, http.StatusOK, rec.Code) + require.True(t, dataStore.called) + + var rotation int + row := sqliteStore.DB().QueryRow("SELECT rotation_required FROM secrets WHERE id = ?", secret.ID) + require.NoError(t, row.Scan(&rotation)) + require.Equal(t, 1, rotation) + + updated, err := dataStore.GetSecretByID(secret.ID) + require.NoError(t, err) + require.True(t, updated.RotationRequired) +} + +func TestHandleManageRequests(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + computer, err := memStore.AddComputer("SERIAL5", "user5", "Mac Mini") + require.NoError(t, err) + secret, _, err := memStore.AddSecret(computer.ID, "password", "secret", false) + require.NoError(t, err) + _, err = memStore.AddRequest(secret.ID, "user5", "Need access", "", nil) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/manage-requests/", nil, "admin") + serveProtected(server, rec, req, server.handleManageRequests) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "SERIAL5") +} + +func TestUserPasswordResetLogsAuditEvent(t *testing.T) { + codec := testCodec(t) + sqliteStore := newTestSQLiteStore(t, codec) + server, sessionManager := newTestServerWithStore(t, sqliteStore) + passwordHash := hashPasswordForTest(t, "password") + _, err := sqliteStore.AddUser("admin", passwordHash, true, true, true, false, "local") + require.NoError(t, err) + target, err := sqliteStore.AddUser("reset", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + form := url.Values{} + form.Set("password", "newpassword") + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/admin/users/"+intToString(target.ID)+"/password/", form, "admin") + req.RemoteAddr = "192.0.2.9:1234" + serveProtected(server, rec, req, server.handleUserPassword) + + require.Equal(t, http.StatusSeeOther, rec.Code) + events, err := sqliteStore.ListAuditEvents() + require.NoError(t, err) + require.Len(t, events, 1) + require.Equal(t, "admin", events[0].Actor) + require.Equal(t, "reset", events[0].TargetUser) + require.Equal(t, "password_reset", events[0].Action) + require.Equal(t, "192.0.2.9", events[0].IPAddress) +} + +func TestUserEditForceResetLogsAuditEvent(t *testing.T) { + codec := testCodec(t) + sqliteStore := newTestSQLiteStore(t, codec) + server, sessionManager := newTestServerWithStore(t, sqliteStore) + passwordHash := hashPasswordForTest(t, "password") + _, err := sqliteStore.AddUser("admin", passwordHash, true, true, true, false, "local") + require.NoError(t, err) + target, err := sqliteStore.AddUser("target", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + form := url.Values{} + form.Set("username", "target") + form.Set("must_reset_password", "on") + form.Set("local_login_enabled", "on") + form.Set("auth_source", "local") + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/admin/users/"+intToString(target.ID)+"/edit/", form, "admin") + req.RemoteAddr = "198.51.100.10:9999" + serveProtected(server, rec, req, server.handleUserEdit) + + require.Equal(t, http.StatusSeeOther, rec.Code) + events, err := sqliteStore.ListAuditEvents() + require.NoError(t, err) + require.Len(t, events, 1) + require.Equal(t, "force_reset_enabled", events[0].Action) + require.Equal(t, "198.51.100.10", events[0].IPAddress) +} + +func TestIDFromPath(t *testing.T) { + id, err := idFromPath("/info/", "/info/123/") + require.NoError(t, err) + require.Equal(t, 123, id) + + _, err = idFromPath("/info/", "/other/123/") + require.Error(t, err) +} + +func TestAuditLogSearch(t *testing.T) { + codec := testCodec(t) + sqliteStore := newTestSQLiteStore(t, codec) + server, sessionManager := newTestServerWithStore(t, sqliteStore) + passwordHash := hashPasswordForTest(t, "password") + _, err := sqliteStore.AddUser("admin", passwordHash, true, true, true, false, "local") + require.NoError(t, err) + _, err = sqliteStore.AddAuditEvent("admin", "user1", "password_reset", "reason", "127.0.0.1") + require.NoError(t, err) + _, err = sqliteStore.AddAuditEvent("admin", "user2", "user_deleted", "", "127.0.0.1") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/admin/audit/?q=reset", nil, "admin") + serveProtected(server, rec, req, server.handleAuditLog) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "password_reset") + require.NotContains(t, rec.Body.String(), "user_deleted") +} + +func TestAuditLogPaginationUsesPageParam(t *testing.T) { + storeSpy := &auditPaginationStore{} + server, sessionManager := newTestServerWithStore(t, storeSpy) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/admin/audit/?page=2", nil, "admin") + serveProtected(server, rec, req, server.handleAuditLog) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 50, storeSpy.lastLimit) + require.Equal(t, 50, storeSpy.lastOffset) +} + +func TestClientIP(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.RemoteAddr = "203.0.113.5:4444" + require.Equal(t, "203.0.113.5", clientIP(req)) + + req.Header.Set("X-Real-IP", "192.0.2.1") + require.Equal(t, "192.0.2.1", clientIP(req)) + + req.Header.Set("X-Forwarded-For", "198.51.100.1, 203.0.113.9") + require.Equal(t, "198.51.100.1", clientIP(req)) +} + +func TestLookupComputer(t *testing.T) { + server, memStore, _ := newTestServer(t) + computer, err := memStore.AddComputer("SERIAL6", "user", "Mac Studio") + require.NoError(t, err) + + byID, err := server.lookupComputer(intToString(computer.ID)) + require.NoError(t, err) + require.Equal(t, "SERIAL6", byID.Serial) + + bySerial, err := server.lookupComputer("serial6") + require.NoError(t, err) + require.Equal(t, computer.ID, bySerial.ID) +} + +func TestCheckinCreatesSecret(t *testing.T) { + server, memStore, _ := newTestServer(t) + + form := url.Values{} + form.Set("serial", "SERIALCHECKIN") + form.Set("recovery_password", "secret-value") + form.Set("username", "user1") + form.Set("macname", "MacBook") + form.Set("secret_type", "recovery_key") + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/checkin/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.handleCheckin(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "\"serial\":\"SERIALCHECKIN\"") + require.Contains(t, rec.Body.String(), "\"rotation_required\":false") + require.Contains(t, rec.Body.String(), "\"new_secret_escrowed\":true") + + computer, err := memStore.GetComputerBySerial("SERIALCHECKIN") + require.NoError(t, err) + require.Equal(t, "user1", computer.Username) +} + +func TestVerifyEscrowed(t *testing.T) { + server, memStore, _ := newTestServer(t) + computer, err := memStore.AddComputer("SERIALVERIFY", "user", "Mac") + require.NoError(t, err) + _, _, err = memStore.AddSecret(computer.ID, "recovery_key", "secret", false) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/verify/SERIALVERIFY/recovery_key/", nil) + server.handleVerify(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "\"escrowed\":true") + require.Contains(t, rec.Body.String(), "\"date_escrowed\"") +} + +func TestVerifyNotEscrowed(t *testing.T) { + server, memStore, _ := newTestServer(t) + _, err := memStore.AddComputer("SERIALVERIFY2", "user", "Mac") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/verify/SERIALVERIFY2/recovery_key/", nil) + server.handleVerify(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "\"escrowed\":false") +} + +func TestVerifyMissingComputer(t *testing.T) { + server, _, _ := newTestServer(t) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/verify/UNKNOWN/recovery_key/", nil) + server.handleVerify(rec, req) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleLoginSuccess(t *testing.T) { + server, _, _ := newTestServer(t) + + form := url.Values{} + form.Set("username", "admin") + form.Set("password", "password") + form.Set("next", "/") + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/login/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.handleLogin(rec, req) + + require.Equal(t, http.StatusSeeOther, rec.Code) + require.NotEmpty(t, rec.Header().Get("Set-Cookie")) +} + +func TestHandleLoginFailure(t *testing.T) { + server, _, _ := newTestServer(t) + + form := url.Values{} + form.Set("username", "admin") + form.Set("password", "wrong") + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/login/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.handleLogin(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "Invalid username or password.") +} + +func TestHandleLoginRequiresLocalLogin(t *testing.T) { + server, memStore, _ := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + _, err := memStore.AddUser("samluser", passwordHash, false, false, false, false, "saml") + require.NoError(t, err) + + form := url.Values{} + form.Set("username", "samluser") + form.Set("password", "password") + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/login/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.handleLogin(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "Invalid username or password.") +} + +func TestHandleLoginRedirectsToReset(t *testing.T) { + server, memStore, _ := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + _, err := memStore.AddUser("resetme", passwordHash, false, false, true, true, "local") + require.NoError(t, err) + + form := url.Values{} + form.Set("username", "resetme") + form.Set("password", "password") + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/login/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.handleLogin(rec, req) + + require.Equal(t, http.StatusSeeOther, rec.Code) + require.Equal(t, "/password/reset/", rec.Header().Get("Location")) +} + +func TestHandlePasswordChange(t *testing.T) { + server, _, sessionManager := newTestServer(t) + + form := url.Values{} + form.Set("current_password", "password") + form.Set("new_password", "Str0ng!Passw0rd") + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/password/change/", form, "admin") + serveProtected(server, rec, req, server.handlePasswordChange) + + require.Equal(t, http.StatusSeeOther, rec.Code) + require.Equal(t, "/", rec.Header().Get("Location")) +} + +func TestHandlePasswordResetClearsFlag(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + resetUser, err := memStore.AddUser("resetuser", passwordHash, false, false, true, true, "local") + require.NoError(t, err) + + form := url.Values{} + form.Set("new_password", "Str0ng!Passw0rd") + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/password/reset/", form, "resetuser") + serveProtected(server, rec, req, server.handlePasswordReset) + + require.Equal(t, http.StatusSeeOther, rec.Code) + updated, err := memStore.GetUserByID(resetUser.ID) + require.NoError(t, err) + require.False(t, updated.MustResetPassword) + require.True(t, verifyPassword("Str0ng!Passw0rd", updated.PasswordHash)) +} + +func TestHandleSAMLLoginStub(t *testing.T) { + server, _, _ := newTestServer(t) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/saml/login/", nil) + server.handleSAMLLogin(rec, req) + require.Equal(t, http.StatusNotImplemented, rec.Code) +} + +func TestHandleUserListRequiresStaff(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + _, err := memStore.AddUser("viewer", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/admin/users/", nil, "viewer") + serveProtected(server, rec, req, server.handleUserList) + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestHandleUserList(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + _, err := memStore.AddUser("second", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/admin/users/", nil, "admin") + serveProtected(server, rec, req, server.handleUserList) + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "admin") + require.Contains(t, rec.Body.String(), "second") +} + +func TestHandleNewUser(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + + form := url.Values{} + form.Set("username", "newuser") + form.Set("password", "Str0ng!Passw0rd") + form.Set("local_login_enabled", "on") + form.Set("auth_source", "local") + form.Set("is_staff", "on") + form.Set("can_approve", "on") + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/admin/users/new/", form, "admin") + serveProtected(server, rec, req, server.handleNewUser) + + require.Equal(t, http.StatusSeeOther, rec.Code) + user, err := memStore.GetUserByUsername("newuser") + require.NoError(t, err) + require.True(t, user.IsStaff) + require.True(t, user.CanApprove) + require.True(t, user.LocalLoginEnabled) + require.Equal(t, "local", user.AuthSource) + require.True(t, verifyPassword("Str0ng!Passw0rd", user.PasswordHash)) + events, err := memStore.ListAuditEvents() + require.NoError(t, err) + event, ok := findAuditEvent(events, "user_created") + require.True(t, ok) + require.Equal(t, "newuser", event.TargetUser) +} + +func TestHandleUserEdit(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + target, err := memStore.AddUser("editor", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + form := url.Values{} + form.Set("username", "updated") + form.Set("is_staff", "on") + form.Set("can_approve", "on") + form.Set("local_login_enabled", "on") + form.Set("must_reset_password", "on") + form.Set("auth_source", "saml") + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/admin/users/"+intToString(target.ID)+"/edit/", form, "admin") + serveProtected(server, rec, req, server.handleUserEdit) + + require.Equal(t, http.StatusSeeOther, rec.Code) + updated, err := memStore.GetUserByID(target.ID) + require.NoError(t, err) + require.Equal(t, "updated", updated.Username) + require.True(t, updated.IsStaff) + require.True(t, updated.CanApprove) + require.True(t, updated.MustResetPassword) + require.Equal(t, "saml", updated.AuthSource) + events, err := memStore.ListAuditEvents() + require.NoError(t, err) + _, ok := findAuditEvent(events, "user_updated") + require.True(t, ok) + _, ok = findAuditEvent(events, "force_reset_enabled") + require.True(t, ok) +} + +func TestHandleUserPassword(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + target, err := memStore.AddUser("reset", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + form := url.Values{} + form.Set("password", "Str0ng!Passw0rd") + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/admin/users/"+intToString(target.ID)+"/password/", form, "admin") + serveProtected(server, rec, req, server.handleUserPassword) + + require.Equal(t, http.StatusSeeOther, rec.Code) + updated, err := memStore.GetUserByID(target.ID) + require.NoError(t, err) + require.True(t, verifyPassword("Str0ng!Passw0rd", updated.PasswordHash)) + require.False(t, updated.MustResetPassword) +} + +func TestHandleUserDelete(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + target, err := memStore.AddUser("remove", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/admin/users/"+intToString(target.ID)+"/delete/", url.Values{}, "admin") + serveProtected(server, rec, req, server.handleUserDelete) + + require.Equal(t, http.StatusSeeOther, rec.Code) + _, err = memStore.GetUserByID(target.ID) + require.Error(t, err) + events, err := memStore.ListAuditEvents() + require.NoError(t, err) + event, ok := findAuditEvent(events, "user_deleted") + require.True(t, ok) + require.Equal(t, "remove", event.TargetUser) +} + +func TestHandleUserDeleteSelf(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + admin, err := memStore.GetUserByUsername("admin") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/admin/users/"+intToString(admin.ID)+"/delete/", url.Values{}, "admin") + serveProtected(server, rec, req, server.handleUserDelete) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "You cannot delete your own account.") +} + +func intToString(value int) string { + return strconv.Itoa(value) +} + +func findAuditEvent(events []*store.AuditEvent, action string) (*store.AuditEvent, bool) { + for _, event := range events { + if event.Action == action { + return event, true + } + } + return nil, false +} + +func newAuthenticatedRequest(t *testing.T, sessionManager *SessionManager, method, target string, body io.Reader, username string) *http.Request { + t.Helper() + req := httptest.NewRequest(method, target, body) + token, err := sessionManager.Create(username) + require.NoError(t, err) + req.AddCookie(&http.Cookie{Name: sessionManager.CookieName(), Value: token}) + return req +} + +func newAuthenticatedFormRequest(t *testing.T, server *Server, sessionManager *SessionManager, method, target string, form url.Values, username string) *http.Request { + t.Helper() + csrfToken, err := server.csrfManager.GenerateToken() + require.NoError(t, err) + form.Set("csrf_token", csrfToken) + req := newAuthenticatedRequest(t, sessionManager, method, target, strings.NewReader(form.Encode()), username) + req.AddCookie(&http.Cookie{Name: server.csrfManager.cookieName, Value: csrfToken}) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + return req +} + +func serveProtected(server *Server, rec *httptest.ResponseRecorder, req *http.Request, handler http.HandlerFunc) { + server.withCSRF(server.withUser(http.HandlerFunc(server.requireAuth(handler)))).ServeHTTP(rec, req) +} + +func hashPasswordForTest(t *testing.T, password string) string { + t.Helper() + hash, err := hashPassword(password) + require.NoError(t, err) + return hash +} + +func testCodec(t *testing.T) *crypto.AesGcmCodec { + t.Helper() + key := make([]byte, 32) + for i := range key { + key[i] = byte(i + 1) + } + encoded := base64.StdEncoding.EncodeToString(key) + codec, err := crypto.NewAesGcmCodecFromBase64Key(encoded) + require.NoError(t, err) + return codec +} + +// Additional comprehensive tests for missing coverage + +func TestHandleNewComputerGET(t *testing.T) { + server, _, sessionManager := newTestServer(t) + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/new/computer/", nil, "admin") + serveProtected(server, rec, req, server.handleNewComputer) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "New Computer") + require.Contains(t, rec.Body.String(), "Serial") +} + +func TestHandleNewComputerPOSTValidation(t *testing.T) { + server, _, sessionManager := newTestServer(t) + form := url.Values{} + form.Set("serial", "") + form.Set("computername", "") + + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/new/computer/", form, "admin") + serveProtected(server, rec, req, server.handleNewComputer) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "Serial number and computer name are required.") +} + +func TestHandleNewSecretGET(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + computer, err := memStore.AddComputer("SERIAL_SECRET_GET", "user", "MacBook Pro") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/new/secret/"+intToString(computer.ID)+"/", nil, "admin") + serveProtected(server, rec, req, server.handleNewSecret) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "New Secret") + require.Contains(t, rec.Body.String(), "MacBook Pro") +} + +func TestHandleNewSecretPOST(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + computer, err := memStore.AddComputer("SERIAL_SECRET_POST", "user", "Mac") + require.NoError(t, err) + + form := url.Values{} + form.Set("secret_type", "recovery_key") + form.Set("secret", "test-recovery-key-123456") + form.Set("rotation_required", "on") + + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/new/secret/"+intToString(computer.ID)+"/", form, "admin") + serveProtected(server, rec, req, server.handleNewSecret) + + require.Equal(t, http.StatusSeeOther, rec.Code) + require.Contains(t, rec.Header().Get("Location"), "/info/"+intToString(computer.ID)) + + secrets, err := memStore.ListSecretsByComputer(computer.ID) + require.NoError(t, err) + require.Len(t, secrets, 1) + require.Equal(t, "recovery_key", secrets[0].SecretType) + require.True(t, secrets[0].RotationRequired) +} + +func TestHandleNewSecretPOSTValidation(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + computer, err := memStore.AddComputer("SERIAL_SECRET_VALID", "user", "Mac") + require.NoError(t, err) + + form := url.Values{} + form.Set("secret_type", "") + form.Set("secret", "") + + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/new/secret/"+intToString(computer.ID)+"/", form, "admin") + serveProtected(server, rec, req, server.handleNewSecret) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "Secret type and value are required.") +} + +func TestHandleNewSecretInvalidComputer(t *testing.T) { + server, _, sessionManager := newTestServer(t) + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/new/secret/99999/", nil, "admin") + serveProtected(server, rec, req, server.handleNewSecret) + + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleComputerInfo(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + computer, err := memStore.AddComputer("SERIAL_INFO", "testuser", "MacBook Air") + require.NoError(t, err) + _, _, err = memStore.AddSecret(computer.ID, "recovery_key", "test-secret", false) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/info/"+intToString(computer.ID)+"/", nil, "admin") + serveProtected(server, rec, req, server.handleComputerInfo) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "SERIAL_INFO") + require.Contains(t, rec.Body.String(), "MacBook Air") + require.Contains(t, rec.Body.String(), "testuser") + require.Contains(t, rec.Body.String(), "Recovery Key") +} + +func TestHandleComputerInfoBySerial(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + _, err := memStore.AddComputer("SERIAL_BY_SERIAL", "user", "iMac") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/info/serial_by_serial/", nil, "admin") + serveProtected(server, rec, req, server.handleComputerInfo) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "SERIAL_BY_SERIAL") +} + +func TestHandleComputerInfoNotFound(t *testing.T) { + server, _, sessionManager := newTestServer(t) + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/info/99999/", nil, "admin") + serveProtected(server, rec, req, server.handleComputerInfo) + + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestHandleSecretInfo(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + computer, err := memStore.AddComputer("SERIAL_SECRET_INFO", "user", "Mac") + require.NoError(t, err) + secret, _, err := memStore.AddSecret(computer.ID, "password", "secret-pass", false) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/info/secret/"+intToString(secret.ID)+"/", nil, "admin") + serveProtected(server, rec, req, server.handleSecretInfo) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "Secret Info") + require.Contains(t, rec.Body.String(), "SERIAL_SECRET_INFO") + require.Contains(t, rec.Body.String(), "Get Key") +} + +func TestHandleSecretInfoWithPendingRequestNonApprover(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + _, err := memStore.AddUser("viewer", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + computer, err := memStore.AddComputer("SERIAL_PENDING", "user", "Mac") + require.NoError(t, err) + secret, _, err := memStore.AddSecret(computer.ID, "recovery_key", "secret", false) + require.NoError(t, err) + _, err = memStore.AddRequest(secret.ID, "viewer", "Need access", "", nil) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/info/secret/"+intToString(secret.ID)+"/", nil, "viewer") + serveProtected(server, rec, req, server.handleSecretInfo) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotContains(t, rec.Body.String(), "Request Key") + require.Contains(t, rec.Body.String(), "Request Pending") +} + +func TestHandleRequestGET(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + computer, err := memStore.AddComputer("SERIAL_REQUEST_GET", "user", "Mac") + require.NoError(t, err) + secret, _, err := memStore.AddSecret(computer.ID, "recovery_key", "secret", false) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/request/"+intToString(secret.ID)+"/", nil, "admin") + serveProtected(server, rec, req, server.handleRequest) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "Request Secret") + require.Contains(t, rec.Body.String(), "SERIAL_REQUEST_GET") +} + +func TestHandleRequestPOSTWithoutAutoApprove(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + server.settings.ApproveOwn = false + computer, err := memStore.AddComputer("SERIAL_NO_AUTO", "user", "Mac") + require.NoError(t, err) + secret, _, err := memStore.AddSecret(computer.ID, "recovery_key", "secret", false) + require.NoError(t, err) + + form := url.Values{} + form.Set("reason_for_request", "Testing no auto-approve") + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/request/"+intToString(secret.ID)+"/", form, "admin") + serveProtected(server, rec, req, server.handleRequest) + + require.Equal(t, http.StatusSeeOther, rec.Code) + requests, err := memStore.ListRequestsBySecret(secret.ID) + require.NoError(t, err) + require.Len(t, requests, 1) + require.Nil(t, requests[0].Approved) +} + +func TestHandleApproveGET(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + server.settings.ApproveOwn = false + computer, err := memStore.AddComputer("SERIAL_APPROVE_GET", "user", "Mac") + require.NoError(t, err) + secret, _, err := memStore.AddSecret(computer.ID, "recovery_key", "secret", false) + require.NoError(t, err) + request, err := memStore.AddRequest(secret.ID, "requester", "Need access", "", nil) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/approve/"+intToString(request.ID)+"/", nil, "admin") + serveProtected(server, rec, req, server.handleApprove) + + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), "Approve Request") + require.Contains(t, rec.Body.String(), "SERIAL_APPROVE_GET") + require.Contains(t, rec.Body.String(), "reason_for_approval") +} + +func TestHandleApprovePOSTApprove(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + server.settings.ApproveOwn = false + computer, err := memStore.AddComputer("SERIAL_APPROVE", "user", "Mac") + require.NoError(t, err) + secret, _, err := memStore.AddSecret(computer.ID, "recovery_key", "secret", false) + require.NoError(t, err) + request, err := memStore.AddRequest(secret.ID, "requester", "Need access", "", nil) + require.NoError(t, err) + + form := url.Values{} + form.Set("approved", "1") + form.Set("reason_for_approval", "Looks good") + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/approve/"+intToString(request.ID)+"/", form, "admin") + serveProtected(server, rec, req, server.handleApprove) + + require.Equal(t, http.StatusSeeOther, rec.Code) + require.Equal(t, "/manage-requests/", rec.Header().Get("Location")) + + updated, err := memStore.GetRequestByID(request.ID) + require.NoError(t, err) + require.NotNil(t, updated.Approved) + require.True(t, *updated.Approved) + require.Equal(t, "admin", updated.AuthUser) + require.Equal(t, "Looks good", updated.ReasonForApproval) +} + +func TestHandleApprovePOSTDeny(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + server.settings.ApproveOwn = false + computer, err := memStore.AddComputer("SERIAL_DENY", "user", "Mac") + require.NoError(t, err) + secret, _, err := memStore.AddSecret(computer.ID, "recovery_key", "secret", false) + require.NoError(t, err) + request, err := memStore.AddRequest(secret.ID, "requester", "Need access", "", nil) + require.NoError(t, err) + + form := url.Values{} + form.Set("approved", "0") + form.Set("reason_for_approval", "Insufficient justification") + rec := httptest.NewRecorder() + req := newAuthenticatedFormRequest(t, server, sessionManager, http.MethodPost, "/approve/"+intToString(request.ID)+"/", form, "admin") + serveProtected(server, rec, req, server.handleApprove) + + require.Equal(t, http.StatusSeeOther, rec.Code) + updated, err := memStore.GetRequestByID(request.ID) + require.NoError(t, err) + require.NotNil(t, updated.Approved) + require.False(t, *updated.Approved) + require.Equal(t, "Insufficient justification", updated.ReasonForApproval) +} + +func TestHandleApproveWithoutPermission(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + _, err := memStore.AddUser("noapprove", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + computer, err := memStore.AddComputer("SERIAL_NO_PERM", "user", "Mac") + require.NoError(t, err) + secret, _, err := memStore.AddSecret(computer.ID, "recovery_key", "secret", false) + require.NoError(t, err) + request, err := memStore.AddRequest(secret.ID, "admin", "Need access", "", nil) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/approve/"+intToString(request.ID)+"/", nil, "noapprove") + serveProtected(server, rec, req, server.handleApprove) + + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestHandleApproveSelfRequestWhenNotAllowed(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + server.settings.ApproveOwn = false + + computer, err := memStore.AddComputer("SERIAL_SELF", "user", "Mac") + require.NoError(t, err) + secret, _, err := memStore.AddSecret(computer.ID, "recovery_key", "secret", false) + require.NoError(t, err) + request, err := memStore.AddRequest(secret.ID, "admin", "Need access", "", nil) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/approve/"+intToString(request.ID)+"/", nil, "admin") + serveProtected(server, rec, req, server.handleApprove) + + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestHandleLogout(t *testing.T) { + server, _, sessionManager := newTestServer(t) + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/logout/", nil, "admin") + server.handleLogout(rec, req) + + require.Equal(t, http.StatusSeeOther, rec.Code) + require.Equal(t, "/login/", rec.Header().Get("Location")) + + cookies := rec.Result().Cookies() + found := false + for _, cookie := range cookies { + if cookie.Name == sessionManager.CookieName() { + found = true + require.Equal(t, "", cookie.Value) + require.Equal(t, -1, cookie.MaxAge) + } + } + require.True(t, found) +} + +func TestHandleRetrieveUnapprovedRequest(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + computer, err := memStore.AddComputer("SERIAL_UNAPPROVED", "user", "Mac") + require.NoError(t, err) + secret, _, err := memStore.AddSecret(computer.ID, "recovery_key", "secret", false) + require.NoError(t, err) + request, err := memStore.AddRequest(secret.ID, "admin", "Need access", "", nil) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/retrieve/"+intToString(request.ID)+"/", nil, "admin") + serveProtected(server, rec, req, server.handleRetrieve) + + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestHandleRetrieveDeniedRequest(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + computer, err := memStore.AddComputer("SERIAL_DENIED", "user", "Mac") + require.NoError(t, err) + secret, _, err := memStore.AddSecret(computer.ID, "recovery_key", "secret", false) + require.NoError(t, err) + approved := false + request, err := memStore.AddRequest(secret.ID, "admin", "Need access", "approver", &approved) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/retrieve/"+intToString(request.ID)+"/", nil, "admin") + serveProtected(server, rec, req, server.handleRetrieve) + + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestHandleRetrieveWrongUser(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + _, err := memStore.AddUser("other", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + computer, err := memStore.AddComputer("SERIAL_WRONG_USER", "user", "Mac") + require.NoError(t, err) + secret, _, err := memStore.AddSecret(computer.ID, "recovery_key", "secret", false) + require.NoError(t, err) + approved := true + request, err := memStore.AddRequest(secret.ID, "admin", "Need access", "approver", &approved) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/retrieve/"+intToString(request.ID)+"/", nil, "other") + serveProtected(server, rec, req, server.handleRetrieve) + + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestHandleManageRequestsRequiresApprover(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + passwordHash := hashPasswordForTest(t, "password") + _, err := memStore.AddUser("viewer", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/manage-requests/", nil, "viewer") + serveProtected(server, rec, req, server.handleManageRequests) + + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestRequireAuthRedirectsToLogin(t *testing.T) { + server, _, _ := newTestServer(t) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + + server.withUser(http.HandlerFunc(server.requireAuth(server.handleIndex))).ServeHTTP(rec, req) + + require.Equal(t, http.StatusSeeOther, rec.Code) + require.Contains(t, rec.Header().Get("Location"), "/login/") + require.Contains(t, rec.Header().Get("Location"), "next=%2F") +} + +func TestCSRFProtectionBlocksInvalidToken(t *testing.T) { + server, _, sessionManager := newTestServer(t) + form := url.Values{} + form.Set("serial", "TEST") + form.Set("username", "user") + form.Set("computername", "Mac") + form.Set("csrf_token", "invalid-token") + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodPost, "/new/computer/", strings.NewReader(form.Encode()), "admin") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + server.withCSRF(server.withUser(http.HandlerFunc(server.requireAuth(server.handleNewComputer)))).ServeHTTP(rec, req) + + require.Equal(t, http.StatusForbidden, rec.Code) + require.Contains(t, rec.Body.String(), "CSRF token missing or invalid") +} + +func TestCheckinDefaultsSecretType(t *testing.T) { + server, memStore, _ := newTestServer(t) + + form := url.Values{} + form.Set("serial", "SERIAL_DEFAULT_TYPE") + form.Set("recovery_password", "secret-value") + form.Set("username", "user1") + form.Set("macname", "MacBook") + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/checkin/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.handleCheckin(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + computer, err := memStore.GetComputerBySerial("SERIAL_DEFAULT_TYPE") + require.NoError(t, err) + secret, err := memStore.GetLatestSecretByComputerAndType(computer.ID, "recovery_key") + require.NoError(t, err) + require.NotNil(t, secret) +} + +func TestCheckinMissingRequiredFields(t *testing.T) { + server, _, _ := newTestServer(t) + + form := url.Values{} + form.Set("serial", "") + form.Set("recovery_password", "") + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/checkin/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.handleCheckin(rec, req) + + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +func TestVerifyInvalidPath(t *testing.T) { + server, _, _ := newTestServer(t) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/verify/invalid/", nil) + server.handleVerify(rec, req) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestAllApproveSettingGrantsApprovalPermission(t *testing.T) { + server, memStore, sessionManager := newTestServer(t) + server.settings.AllApprove = true + passwordHash := hashPasswordForTest(t, "password") + _, err := memStore.AddUser("regularuser", passwordHash, false, false, true, false, "local") + require.NoError(t, err) + + computer, err := memStore.AddComputer("SERIAL_ALL_APPROVE", "user", "Mac") + require.NoError(t, err) + secret, _, err := memStore.AddSecret(computer.ID, "recovery_key", "secret", false) + require.NoError(t, err) + request, err := memStore.AddRequest(secret.ID, "admin", "Need access", "", nil) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := newAuthenticatedRequest(t, sessionManager, http.MethodGet, "/approve/"+intToString(request.ID)+"/", nil, "regularuser") + serveProtected(server, rec, req, server.handleApprove) + + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestCheckinNewSecretEscrowedFirstTime(t *testing.T) { + server, memStore, _ := newTestServer(t) + + form := url.Values{} + form.Set("serial", "SERIAL_NEW_SECRET") + form.Set("recovery_password", "first-secret-value") + form.Set("username", "user1") + form.Set("macname", "MacBook") + form.Set("secret_type", "recovery_key") + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/checkin/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.handleCheckin(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var data map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &data)) + require.Equal(t, "SERIAL_NEW_SECRET", data["serial"]) + require.Equal(t, true, data["new_secret_escrowed"]) + + computer, err := memStore.GetComputerBySerial("SERIAL_NEW_SECRET") + require.NoError(t, err) + secrets, err := memStore.ListSecretsByComputer(computer.ID) + require.NoError(t, err) + require.Len(t, secrets, 1) +} + +func TestCheckinNewSecretEscrowedDuplicate(t *testing.T) { + server, memStore, _ := newTestServer(t) + + // First escrow + form := url.Values{} + form.Set("serial", "SERIAL_DUPLICATE") + form.Set("recovery_password", "same-secret-value") + form.Set("username", "user1") + form.Set("macname", "MacBook") + form.Set("secret_type", "recovery_key") + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/checkin/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.handleCheckin(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var firstData map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &firstData)) + require.Equal(t, true, firstData["new_secret_escrowed"]) + + // Second escrow with same secret + rec2 := httptest.NewRecorder() + req2 := httptest.NewRequest(http.MethodPost, "/checkin/", strings.NewReader(form.Encode())) + req2.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.handleCheckin(rec2, req2) + require.Equal(t, http.StatusOK, rec2.Code) + + var secondData map[string]any + require.NoError(t, json.Unmarshal(rec2.Body.Bytes(), &secondData)) + require.Equal(t, false, secondData["new_secret_escrowed"]) + + // Verify only one secret was created + computer, err := memStore.GetComputerBySerial("SERIAL_DUPLICATE") + require.NoError(t, err) + secrets, err := memStore.ListSecretsByComputer(computer.ID) + require.NoError(t, err) + require.Len(t, secrets, 1) +} + +func TestCheckinNewSecretEscrowedDifferentKey(t *testing.T) { + server, memStore, _ := newTestServer(t) + + // First escrow + form := url.Values{} + form.Set("serial", "SERIAL_DIFF_KEY") + form.Set("recovery_password", "first-secret") + form.Set("username", "user1") + form.Set("macname", "MacBook") + form.Set("secret_type", "recovery_key") + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/checkin/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.handleCheckin(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var firstData map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &firstData)) + require.Equal(t, true, firstData["new_secret_escrowed"]) + + // Second escrow with different secret value + form.Set("recovery_password", "second-different-secret") + rec2 := httptest.NewRecorder() + req2 := httptest.NewRequest(http.MethodPost, "/checkin/", strings.NewReader(form.Encode())) + req2.Header.Set("Content-Type", "application/x-www-form-urlencoded") + server.handleCheckin(rec2, req2) + require.Equal(t, http.StatusOK, rec2.Code) + + var secondData map[string]any + require.NoError(t, json.Unmarshal(rec2.Body.Bytes(), &secondData)) + require.Equal(t, true, secondData["new_secret_escrowed"]) + + // Verify two secrets were created + computer, err := memStore.GetComputerBySerial("SERIAL_DIFF_KEY") + require.NoError(t, err) + secrets, err := memStore.ListSecretsByComputer(computer.ID) + require.NoError(t, err) + require.Len(t, secrets, 2) +} diff --git a/internal/app/password.go b/internal/app/password.go new file mode 100644 index 0000000..ff8e851 --- /dev/null +++ b/internal/app/password.go @@ -0,0 +1,79 @@ +package app + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "strings" + + "golang.org/x/crypto/argon2" +) + +const ( + argon2Time = 1 + argon2Memory = 64 * 1024 + argon2Threads = 4 + argon2KeyLen = 32 + argon2SaltLen = 16 +) + +func hashPassword(plaintext string) (string, error) { + if plaintext == "" { + return "", errors.New("password is required") + } + salt := make([]byte, argon2SaltLen) + if _, err := rand.Read(salt); err != nil { + return "", fmt.Errorf("generate salt: %w", err) + } + hash := argon2.IDKey([]byte(plaintext), salt, argon2Time, argon2Memory, argon2Threads, argon2KeyLen) + return fmt.Sprintf("$argon2id$%d$%d$%d$%s$%s", + argon2Time, + argon2Memory, + argon2Threads, + base64.RawStdEncoding.EncodeToString(salt), + base64.RawStdEncoding.EncodeToString(hash), + ), nil +} + +func verifyPassword(plaintext, encoded string) bool { + params, salt, hash, err := parseArgon2id(encoded) + if err != nil { + return false + } + check := argon2.IDKey([]byte(plaintext), salt, params.time, params.memory, params.threads, uint32(len(hash))) + return subtle.ConstantTimeCompare(check, hash) == 1 +} + +type argon2Params struct { + time uint32 + memory uint32 + threads uint8 +} + +func parseArgon2id(encoded string) (argon2Params, []byte, []byte, error) { + parts := strings.Split(encoded, "$") + if len(parts) != 7 || parts[1] != "argon2id" { + return argon2Params{}, nil, nil, errors.New("invalid argon2id hash") + } + var params argon2Params + if _, err := fmt.Sscanf(parts[2], "%d", ¶ms.time); err != nil { + return argon2Params{}, nil, nil, errors.New("invalid argon2id time") + } + if _, err := fmt.Sscanf(parts[3], "%d", ¶ms.memory); err != nil { + return argon2Params{}, nil, nil, errors.New("invalid argon2id memory") + } + if _, err := fmt.Sscanf(parts[4], "%d", ¶ms.threads); err != nil { + return argon2Params{}, nil, nil, errors.New("invalid argon2id threads") + } + salt, err := base64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + return argon2Params{}, nil, nil, errors.New("invalid argon2id salt") + } + hash, err := base64.RawStdEncoding.DecodeString(parts[6]) + if err != nil { + return argon2Params{}, nil, nil, errors.New("invalid argon2id hash") + } + return params, salt, hash, nil +} diff --git a/internal/app/password_export.go b/internal/app/password_export.go new file mode 100644 index 0000000..498a5f5 --- /dev/null +++ b/internal/app/password_export.go @@ -0,0 +1,6 @@ +package app + +// HashPassword exposes the Argon2id hash for CLI utilities. +func HashPassword(plaintext string) (string, error) { + return hashPassword(plaintext) +} diff --git a/internal/app/password_test.go b/internal/app/password_test.go new file mode 100644 index 0000000..c950a74 --- /dev/null +++ b/internal/app/password_test.go @@ -0,0 +1,25 @@ +package app + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPasswordHashAndVerify(t *testing.T) { + hash, err := hashPassword("secret") + require.NoError(t, err) + require.True(t, verifyPassword("secret", hash)) + require.False(t, verifyPassword("wrong", hash)) +} + +func TestPasswordHashRejectsEmpty(t *testing.T) { + _, err := hashPassword("") + require.Error(t, err) +} + +func TestHashPasswordExported(t *testing.T) { + hash, err := HashPassword("secret") + require.NoError(t, err) + require.Contains(t, hash, "$argon2id$") +} diff --git a/internal/app/renderer.go b/internal/app/renderer.go new file mode 100644 index 0000000..be71147 --- /dev/null +++ b/internal/app/renderer.go @@ -0,0 +1,45 @@ +package app + +import ( + "fmt" + "html/template" + "net/http" + "path/filepath" +) + +type Renderer struct { + baseLayout string + pageDir string + cache map[string]*template.Template +} + +func NewRenderer(baseLayout, pageDir string) *Renderer { + return &Renderer{ + baseLayout: baseLayout, + pageDir: pageDir, + cache: make(map[string]*template.Template), + } +} + +func (r *Renderer) Render(w http.ResponseWriter, name string, data any) error { + page, ok := r.cache[name] + if !ok { + layout := r.baseLayout + pagePath := filepath.Join(r.pageDir, name+".html") + parsed, err := template.New("base").Funcs(template.FuncMap{ + "add": func(a, b int) int { return a + b }, + "sub": func(a, b int) int { return a - b }, + }).ParseFiles(layout, pagePath) + if err != nil { + return fmt.Errorf("parse templates: %w", err) + } + r.cache[name] = parsed + page = parsed + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if err := page.ExecuteTemplate(w, "base", data); err != nil { + return fmt.Errorf("render template: %w", err) + } + return nil +} diff --git a/internal/app/renderer_test.go b/internal/app/renderer_test.go new file mode 100644 index 0000000..f2b022e --- /dev/null +++ b/internal/app/renderer_test.go @@ -0,0 +1,24 @@ +package app + +import ( + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRendererRender(t *testing.T) { + tmp := t.TempDir() + layout := filepath.Join(tmp, "base.html") + page := filepath.Join(tmp, "page.html") + + require.NoError(t, os.WriteFile(layout, []byte("{{define \"base\"}}Hello {{block \"content\" .}}{{end}}{{end}}"), 0o600)) + require.NoError(t, os.WriteFile(page, []byte("{{define \"content\"}}World{{end}}"), 0o600)) + + renderer := NewRenderer(layout, tmp) + recorder := httptest.NewRecorder() + require.NoError(t, renderer.Render(recorder, "page", TemplateData{})) + require.Equal(t, "Hello World", recorder.Body.String()) +} diff --git a/internal/app/saml.go b/internal/app/saml.go new file mode 100644 index 0000000..29c5b61 --- /dev/null +++ b/internal/app/saml.go @@ -0,0 +1,179 @@ +package app + +import ( + "context" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "net/http" + "net/url" + "os" + "strings" + + "github.com/crewjam/saml" + "github.com/crewjam/saml/samlsp" +) + +func BuildSAMLProvider(cfg *SAMLConfig) (*samlsp.Middleware, error) { + rootURL, err := url.Parse(cfg.RootURL) + if err != nil { + return nil, fmt.Errorf("parse saml root url: %w", err) + } + + idpMetadata, err := loadIDPMetadata(cfg) + if err != nil { + return nil, err + } + + entityID := cfg.EntityID + if entityID == "" { + entityID = rootURL.ResolveReference(&url.URL{Path: cfg.MetadataURLPath}).String() + } + + opts := samlsp.Options{ + EntityID: entityID, + URL: *rootURL, + IDPMetadata: idpMetadata, + AllowIDPInitiated: cfg.AllowIDPInitiated, + DefaultRedirectURI: cfg.DefaultRedirectURI, + SignRequest: cfg.SignRequest, + } + + // Load certificate and private key if provided + if cfg.CertificatePath != "" && cfg.PrivateKeyPath != "" { + keyPair, err := tls.LoadX509KeyPair(cfg.CertificatePath, cfg.PrivateKeyPath) + if err != nil { + return nil, fmt.Errorf("load saml keypair: %w", err) + } + + cert, err := parseX509Certificate(keyPair.Certificate) + if err != nil { + return nil, fmt.Errorf("parse saml certificate: %w", err) + } + + privateKey, ok := keyPair.PrivateKey.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("saml private key must be RSA") + } + + opts.Key = privateKey + opts.Certificate = cert + } + + middleware, err := samlsp.New(opts) + if err != nil { + return nil, fmt.Errorf("init saml: %w", err) + } + + metadataURL := rootURL.ResolveReference(&url.URL{Path: cfg.MetadataURLPath}) + acsURL := rootURL.ResolveReference(&url.URL{Path: cfg.AcsURLPath}) + sloURL := rootURL.ResolveReference(&url.URL{Path: cfg.SloURLPath}) + middleware.ServiceProvider.MetadataURL = *metadataURL + middleware.ServiceProvider.AcsURL = *acsURL + middleware.ServiceProvider.SloURL = *sloURL + + return middleware, nil +} + +func loadIDPMetadata(cfg *SAMLConfig) (*saml.EntityDescriptor, error) { + if cfg.IDPMetadataPath != "" { + data, err := os.ReadFile(cfg.IDPMetadataPath) + if err != nil { + return nil, fmt.Errorf("read idp metadata: %w", err) + } + metadata, err := samlsp.ParseMetadata(data) + if err != nil { + return nil, fmt.Errorf("parse idp metadata: %w", err) + } + return metadata, nil + } + metadataURL, err := url.Parse(cfg.IDPMetadataURL) + if err != nil { + return nil, fmt.Errorf("parse idp metadata url: %w", err) + } + metadata, err := samlsp.FetchMetadata(context.Background(), http.DefaultClient, *metadataURL) + if err != nil { + return nil, fmt.Errorf("fetch idp metadata: %w", err) + } + return metadata, nil +} + +func parseX509Certificate(certs [][]byte) (*x509.Certificate, error) { + if len(certs) == 0 { + return nil, errors.New("missing certificate") + } + return x509.ParseCertificate(certs[0]) +} + +func usernameFromSAML(session samlsp.Session, cfg *SAMLConfig) string { + if cfg.UseNameIDAsUsername { + if claims, ok := session.(samlsp.JWTSessionClaims); ok { + if claims.Subject != "" { + return claims.Subject + } + } + } + attributes := attributesFromSession(session) + if cfg.UsernameAttribute != "" { + if value := attributes.Get(cfg.UsernameAttribute); value != "" { + return value + } + } + for key, mapped := range cfg.AttributeMapping { + if mapped == "username" { + if value := attributes.Get(key); value != "" { + return value + } + } + } + if value := attributes.Get("uid"); value != "" { + return value + } + return "" +} + +func attributesFromSession(session samlsp.Session) samlsp.Attributes { + if session == nil { + return nil + } + if withAttrs, ok := session.(samlsp.SessionWithAttributes); ok { + return withAttrs.GetAttributes() + } + return nil +} + +func groupMembership(attributes samlsp.Attributes, attr string) []string { + if attributes == nil { + return nil + } + return attributes[attr] +} + +func matchesGroup(groups []string, target []string) bool { + for _, candidate := range groups { + for _, group := range target { + if strings.EqualFold(candidate, group) { + return true + } + } + } + return false +} + +func resolveSAMLPermissions(groups []string, cfg *SAMLConfig) (bool, bool) { + isStaff := false + canApprove := false + if matchesGroup(groups, cfg.SuperuserGroups) { + isStaff = true + canApprove = true + } + if matchesGroup(groups, cfg.StaffGroups) { + isStaff = true + } + if matchesGroup(groups, cfg.CanApproveGroups) { + canApprove = true + } + return isStaff, canApprove +} diff --git a/internal/app/saml_config.go b/internal/app/saml_config.go new file mode 100644 index 0000000..944306e --- /dev/null +++ b/internal/app/saml_config.go @@ -0,0 +1,76 @@ +package app + +import ( + "errors" + "fmt" + "os" + + "gopkg.in/yaml.v3" +) + +type SAMLConfig struct { + RootURL string `yaml:"root_url"` + EntityID string `yaml:"entity_id"` + IDPMetadataPath string `yaml:"idp_metadata_path"` + IDPMetadataURL string `yaml:"idp_metadata_url"` + CertificatePath string `yaml:"certificate_path"` + PrivateKeyPath string `yaml:"private_key_path"` + AllowIDPInitiated bool `yaml:"allow_idp_initiated"` + SignRequest bool `yaml:"sign_request"` + UseNameIDAsUsername bool `yaml:"use_name_id_as_username"` + CreateUnknownUser bool `yaml:"create_unknown_user"` + UsernameAttribute string `yaml:"username_attribute"` + AttributeMapping map[string]string `yaml:"attribute_mapping"` + GroupsAttribute string `yaml:"groups_attribute"` + StaffGroups []string `yaml:"staff_groups"` + SuperuserGroups []string `yaml:"superuser_groups"` + CanApproveGroups []string `yaml:"can_approve_groups"` + DefaultAuthSource string `yaml:"auth_source"` + DefaultLocalLogin bool `yaml:"local_login_enabled"` + DefaultMustReset bool `yaml:"must_reset_password"` + DefaultRedirectURI string `yaml:"default_redirect_uri"` + MetadataURLPath string `yaml:"metadata_url_path"` + AcsURLPath string `yaml:"acs_url_path"` + SloURLPath string `yaml:"slo_url_path"` +} + +func LoadSAMLConfig(path string) (*SAMLConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read saml config: %w", err) + } + var cfg SAMLConfig + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parse saml yaml: %w", err) + } + if cfg.RootURL == "" { + return nil, errors.New("saml config missing root_url") + } + if cfg.IDPMetadataPath == "" && cfg.IDPMetadataURL == "" { + return nil, errors.New("saml config missing idp metadata path or url") + } + // Certificate and private key are optional - only needed if sign_request is true + // or if the IdP encrypts assertions + if cfg.SignRequest && (cfg.CertificatePath == "" || cfg.PrivateKeyPath == "") { + return nil, errors.New("saml config requires certificate and private key when sign_request is enabled") + } + if cfg.GroupsAttribute == "" { + cfg.GroupsAttribute = "memberOf" + } + if cfg.DefaultAuthSource == "" { + cfg.DefaultAuthSource = "saml" + } + if cfg.DefaultRedirectURI == "" { + cfg.DefaultRedirectURI = "/" + } + if cfg.MetadataURLPath == "" { + cfg.MetadataURLPath = "/saml2/metadata/" + } + if cfg.AcsURLPath == "" { + cfg.AcsURLPath = "/saml2/acs/" + } + if cfg.SloURLPath == "" { + cfg.SloURLPath = "/saml2/ls/" + } + return &cfg, nil +} diff --git a/internal/app/saml_config_test.go b/internal/app/saml_config_test.go new file mode 100644 index 0000000..aed09d8 --- /dev/null +++ b/internal/app/saml_config_test.go @@ -0,0 +1,67 @@ +package app + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLoadSAMLConfigDefaults(t *testing.T) { + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "saml.yaml") + err := os.WriteFile(cfgPath, []byte(`root_url: https://crypt.example.com +idp_metadata_path: /tmp/metadata.xml +`), 0o600) + require.NoError(t, err) + + cfg, err := LoadSAMLConfig(cfgPath) + require.NoError(t, err) + require.Equal(t, "/saml2/metadata/", cfg.MetadataURLPath) + require.Equal(t, "/saml2/acs/", cfg.AcsURLPath) + require.Equal(t, "/saml2/ls/", cfg.SloURLPath) + require.Equal(t, "memberOf", cfg.GroupsAttribute) + require.Equal(t, "saml", cfg.DefaultAuthSource) + require.Equal(t, "/", cfg.DefaultRedirectURI) +} + +func TestLoadSAMLConfigWithoutCertificates(t *testing.T) { + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "saml.yaml") + // Config without certificate/private_key - should work when sign_request is false + err := os.WriteFile(cfgPath, []byte(`root_url: https://crypt.example.com +idp_metadata_url: https://idp.example.com/metadata +`), 0o600) + require.NoError(t, err) + + cfg, err := LoadSAMLConfig(cfgPath) + require.NoError(t, err) + require.Empty(t, cfg.CertificatePath) + require.Empty(t, cfg.PrivateKeyPath) +} + +func TestLoadSAMLConfigSignRequestRequiresCertificates(t *testing.T) { + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "saml.yaml") + // sign_request: true requires certificate and private key + err := os.WriteFile(cfgPath, []byte(`root_url: https://crypt.example.com +idp_metadata_url: https://idp.example.com/metadata +sign_request: true +`), 0o600) + require.NoError(t, err) + + _, err = LoadSAMLConfig(cfgPath) + require.Error(t, err) + require.Contains(t, err.Error(), "certificate and private key") +} + +func TestLoadSAMLConfigRequiresFields(t *testing.T) { + tmp := t.TempDir() + cfgPath := filepath.Join(tmp, "saml.yaml") + err := os.WriteFile(cfgPath, []byte(`root_url: ""`), 0o600) + require.NoError(t, err) + + _, err = LoadSAMLConfig(cfgPath) + require.Error(t, err) +} diff --git a/internal/app/saml_test.go b/internal/app/saml_test.go new file mode 100644 index 0000000..e3fba71 --- /dev/null +++ b/internal/app/saml_test.go @@ -0,0 +1,56 @@ +package app + +import ( + "testing" + + "github.com/crewjam/saml/samlsp" + "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/require" +) + +func TestUsernameFromSAMLNameID(t *testing.T) { + cfg := &SAMLConfig{UseNameIDAsUsername: true} + claims := samlsp.JWTSessionClaims{StandardClaims: jwt.StandardClaims{Subject: "nameid-user"}} + require.Equal(t, "nameid-user", usernameFromSAML(claims, cfg)) +} + +func TestUsernameFromSAMLAttributeMapping(t *testing.T) { + cfg := &SAMLConfig{ + UseNameIDAsUsername: false, + AttributeMapping: map[string]string{"uid": "username"}, + } + claims := samlsp.JWTSessionClaims{ + Attributes: samlsp.Attributes{ + "uid": []string{"mapped-user"}, + }, + } + require.Equal(t, "mapped-user", usernameFromSAML(claims, cfg)) +} + +func TestUsernameFromSAMLUsernameAttribute(t *testing.T) { + cfg := &SAMLConfig{ + UseNameIDAsUsername: false, + UsernameAttribute: "email", + } + claims := samlsp.JWTSessionClaims{ + Attributes: samlsp.Attributes{ + "email": []string{"user@example.com"}, + }, + } + require.Equal(t, "user@example.com", usernameFromSAML(claims, cfg)) +} + +func TestResolveSAMLPermissions(t *testing.T) { + cfg := &SAMLConfig{ + StaffGroups: []string{"staff"}, + SuperuserGroups: []string{"super"}, + CanApproveGroups: []string{"approvers"}, + } + isStaff, canApprove := resolveSAMLPermissions([]string{"approvers"}, cfg) + require.False(t, isStaff) + require.True(t, canApprove) + + isStaff, canApprove = resolveSAMLPermissions([]string{"super"}, cfg) + require.True(t, isStaff) + require.True(t, canApprove) +} diff --git a/internal/app/server.go b/internal/app/server.go new file mode 100644 index 0000000..5031ba0 --- /dev/null +++ b/internal/app/server.go @@ -0,0 +1,168 @@ +package app + +import ( + "context" + "crypt-server/internal/store" + "log" + "net/http" + "net/url" + "strings" + + "github.com/crewjam/saml/samlsp" +) + +// Version is the application version displayed in the UI. +// Set at build time via: go build -ldflags "-X crypt-server/internal/app.Version=1.0.0" +var Version = "0.0.0-dev" + +type Server struct { + store store.Store + renderer *Renderer + logger *log.Logger + sessionManager *SessionManager + csrfManager *CSRFManager + samlSP *samlsp.Middleware + samlConfig *SAMLConfig + settings Settings +} + +func NewServer(store store.Store, renderer *Renderer, logger *log.Logger, sessionManager *SessionManager, csrfManager *CSRFManager, samlSP *samlsp.Middleware, samlConfig *SAMLConfig, settings Settings) *Server { + server := &Server{ + store: store, + renderer: renderer, + logger: logger, + sessionManager: sessionManager, + csrfManager: csrfManager, + samlSP: samlSP, + samlConfig: samlConfig, + settings: settings, + } + server.startRequestCleanupJob() + return server +} + +func (s *Server) Routes() http.Handler { + mux := http.NewServeMux() + + mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("web/static")))) + mux.HandleFunc("/login/", s.handleLogin) + mux.HandleFunc("/logout/", s.handleLogout) + mux.HandleFunc("/saml/login/", s.handleSAMLLogin) + mux.HandleFunc("/saml2/login/", s.handleSAMLLogin) + if s.samlSP != nil { + mux.Handle("/saml/", s.samlSP) + mux.Handle("/saml2/", s.samlSP) + } + mux.HandleFunc("/checkin/", s.handleCheckin) + mux.HandleFunc("/verify/", s.handleVerify) + mux.HandleFunc("/", s.requireAuth(s.handleIndex)) + mux.HandleFunc("/ajax/", s.requireAuth(s.handleTableAjax)) + mux.HandleFunc("/new/computer/", s.requireAuth(s.handleNewComputer)) + mux.HandleFunc("/new/secret/", s.requireAuth(s.handleNewSecret)) + mux.HandleFunc("/info/secret/", s.requireAuth(s.handleSecretInfo)) + mux.HandleFunc("/info/", s.requireAuth(s.handleComputerInfo)) + mux.HandleFunc("/request/", s.requireAuth(s.handleRequest)) + mux.HandleFunc("/retrieve/", s.requireAuth(s.handleRetrieve)) + mux.HandleFunc("/approve/", s.requireAuth(s.handleApprove)) + mux.HandleFunc("/manage-requests/", s.requireAuth(s.handleManageRequests)) + mux.HandleFunc("/admin/users/", s.requireAuth(s.handleAdminUsers)) + mux.HandleFunc("/admin/audit/", s.requireAuth(s.handleAuditLog)) + mux.HandleFunc("/password/change/", s.requireAuth(s.handlePasswordChange)) + mux.HandleFunc("/password/reset/", s.requireAuth(s.handlePasswordReset)) + + return withTrailingSlashRedirect(s.withCSRF(s.withUser(mux))) +} + +func withTrailingSlashRedirect(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Don't redirect static files + if !strings.HasPrefix(r.URL.Path, "/static/") && r.URL.Path != "/" && !strings.HasSuffix(r.URL.Path, "/") { + http.Redirect(w, r, r.URL.Path+"/", http.StatusMovedPermanently) + return + } + next.ServeHTTP(w, r) + }) +} + +func (s *Server) requireAuth(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + user := s.currentUser(r) + if !user.IsAuthenticated { + http.Redirect(w, r, "/login/?next="+urlQueryEscape(r.URL.Path), http.StatusSeeOther) + return + } + if user.MustResetPassword && user.LocalLoginEnabled && r.URL.Path != "/password/reset/" && r.URL.Path != "/logout/" { + http.Redirect(w, r, "/password/reset/", http.StatusSeeOther) + return + } + next(w, r) + } +} + +func urlQueryEscape(value string) string { + return url.QueryEscape(value) +} + +type contextKey string + +const userContextKey contextKey = "user" + +func (s *Server) withCSRF(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if s.csrfManager == nil { + next.ServeHTTP(w, r) + return + } + if !s.isCSRFExempt(r) && isUnsafeMethod(r.Method) { + if !s.csrfManager.ValidateRequest(r) { + http.Error(w, "CSRF token missing or invalid", http.StatusForbidden) + return + } + } + if _, err := s.csrfManager.EnsureToken(w, r, s.settings.CookieSecure); err != nil { + s.logger.Printf("csrf error: %v", err) + http.Error(w, "Something went wrong", http.StatusInternalServerError) + return + } + next.ServeHTTP(w, r) + }) +} + +func isUnsafeMethod(method string) bool { + switch method { + case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete: + return true + default: + return false + } +} + +func (s *Server) isCSRFExempt(r *http.Request) bool { + switch { + case strings.HasPrefix(r.URL.Path, "/checkin/"): + return true + case strings.HasPrefix(r.URL.Path, "/verify/"): + return true + case strings.HasPrefix(r.URL.Path, "/saml/"): + return true + case strings.HasPrefix(r.URL.Path, "/saml2/"): + return true + default: + return false + } +} + +func (s *Server) withUser(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := s.loadUserFromRequest(r) + if user != nil { + ctx := contextWithUser(r.Context(), user) + r = r.WithContext(ctx) + } + next.ServeHTTP(w, r) + }) +} + +func contextWithUser(ctx context.Context, user *User) context.Context { + return context.WithValue(ctx, userContextKey, user) +} diff --git a/internal/app/server_test.go b/internal/app/server_test.go new file mode 100644 index 0000000..7454dd2 --- /dev/null +++ b/internal/app/server_test.go @@ -0,0 +1,115 @@ +package app + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsUnsafeMethod(t *testing.T) { + tests := []struct { + method string + expected bool + }{ + {http.MethodGet, false}, + {http.MethodHead, false}, + {http.MethodOptions, false}, + {http.MethodPost, true}, + {http.MethodPut, true}, + {http.MethodPatch, true}, + {http.MethodDelete, true}, + {"CUSTOM", false}, + } + + for _, tt := range tests { + t.Run(tt.method, func(t *testing.T) { + require.Equal(t, tt.expected, isUnsafeMethod(tt.method)) + }) + } +} + +func TestWithTrailingSlashRedirect(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + wrapped := withTrailingSlashRedirect(handler) + + tests := []struct { + name string + path string + expectedStatus int + expectedPath string + }{ + {"root path", "/", http.StatusOK, ""}, + {"path with trailing slash", "/foo/", http.StatusOK, ""}, + {"path without trailing slash", "/foo", http.StatusMovedPermanently, "/foo/"}, + {"nested path without slash", "/foo/bar", http.StatusMovedPermanently, "/foo/bar/"}, + {"nested path with slash", "/foo/bar/", http.StatusOK, ""}, + {"static path without slash", "/static/css", http.StatusOK, ""}, + {"static path with slash", "/static/css/", http.StatusOK, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tt.path, nil) + rec := httptest.NewRecorder() + + wrapped.ServeHTTP(rec, req) + + require.Equal(t, tt.expectedStatus, rec.Code) + if tt.expectedPath != "" { + require.Equal(t, tt.expectedPath, rec.Header().Get("Location")) + } + }) + } +} + +func TestUrlQueryEscape(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"/foo/bar/", "%2Ffoo%2Fbar%2F"}, + {"hello world", "hello+world"}, + {"special=chars&more", "special%3Dchars%26more"}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + require.Equal(t, tt.expected, urlQueryEscape(tt.input)) + }) + } +} + +func TestServerIsCSRFExempt(t *testing.T) { + s := &Server{} + + tests := []struct { + path string + expected bool + }{ + {"/checkin/", true}, + {"/checkin/foo", true}, + {"/verify/", true}, + {"/verify/ABC123", true}, + {"/saml/", true}, + {"/saml/acs", true}, + {"/saml2/", true}, + {"/saml2/acs", true}, + {"/login/", false}, + {"/", false}, + {"/admin/users/", false}, + {"/request/", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, tt.path, nil) + require.Equal(t, tt.expected, s.isCSRFExempt(req)) + }) + } +} diff --git a/internal/app/session.go b/internal/app/session.go new file mode 100644 index 0000000..ba145cb --- /dev/null +++ b/internal/app/session.go @@ -0,0 +1,112 @@ +package app + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "time" +) + +type SessionManager struct { + key []byte + cookieName string + ttl time.Duration +} + +func NewSessionManager(key []byte, cookieName string, ttl time.Duration) (*SessionManager, error) { + if len(key) < 32 { + return nil, errors.New("session key must be at least 32 bytes") + } + if cookieName == "" { + return nil, errors.New("session cookie name is required") + } + if ttl <= 0 { + return nil, errors.New("session ttl must be positive") + } + return &SessionManager{key: key, cookieName: cookieName, ttl: ttl}, nil +} + +func (s *SessionManager) Create(username string) (string, error) { + return s.createAt(username, time.Now()) +} + +func (s *SessionManager) createAt(username string, now time.Time) (string, error) { + if username == "" { + return "", errors.New("username is required") + } + payload := fmt.Sprintf("%s|%d", username, now.Unix()) + signature := s.sign(payload) + raw := payload + "|" + signature + return base64.RawURLEncoding.EncodeToString([]byte(raw)), nil +} + +func (s *SessionManager) Validate(token string) (string, bool) { + return s.validateAt(token, time.Now()) +} + +func (s *SessionManager) validateAt(token string, now time.Time) (string, bool) { + if token == "" { + return "", false + } + decoded, err := base64.RawURLEncoding.DecodeString(token) + if err != nil { + return "", false + } + parts := strings.Split(string(decoded), "|") + if len(parts) != 3 { + return "", false + } + username := parts[0] + timestamp, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return "", false + } + payload := parts[0] + "|" + parts[1] + expected := s.sign(payload) + if !hmac.Equal([]byte(expected), []byte(parts[2])) { + return "", false + } + if now.After(time.Unix(timestamp, 0).Add(s.ttl)) { + return "", false + } + return username, true +} + +func (s *SessionManager) SetCookie(w http.ResponseWriter, value string, secure bool) { + http.SetCookie(w, &http.Cookie{ + Name: s.cookieName, + Value: value, + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + Secure: secure, + }) +} + +func (s *SessionManager) ClearCookie(w http.ResponseWriter, secure bool) { + http.SetCookie(w, &http.Cookie{ + Name: s.cookieName, + Value: "", + Path: "/", + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + MaxAge: -1, + Secure: secure, + }) +} + +func (s *SessionManager) CookieName() string { + return s.cookieName +} + +func (s *SessionManager) sign(payload string) string { + mac := hmac.New(sha256.New, s.key) + mac.Write([]byte(payload)) + return hex.EncodeToString(mac.Sum(nil)) +} diff --git a/internal/app/session_test.go b/internal/app/session_test.go new file mode 100644 index 0000000..4c9c891 --- /dev/null +++ b/internal/app/session_test.go @@ -0,0 +1,48 @@ +package app + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSessionManagerRoundTrip(t *testing.T) { + manager, err := NewSessionManager([]byte("test-session-key-32-bytes-long!!"), "crypt_session", time.Hour) + require.NoError(t, err) + + token, err := manager.createAt("admin", time.Unix(1000, 0)) + require.NoError(t, err) + + username, ok := manager.validateAt(token, time.Unix(2000, 0)) + require.True(t, ok) + require.Equal(t, "admin", username) +} + +func TestSessionManagerExpired(t *testing.T) { + manager, err := NewSessionManager([]byte("test-session-key-32-bytes-long!!"), "crypt_session", time.Hour) + require.NoError(t, err) + + token, err := manager.createAt("admin", time.Unix(1000, 0)) + require.NoError(t, err) + + _, ok := manager.validateAt(token, time.Unix(1000+int64(2*time.Hour.Seconds()), 0)) + require.False(t, ok) +} + +func TestSessionManagerInvalidSignature(t *testing.T) { + manager, err := NewSessionManager([]byte("test-session-key-32-bytes-long!!"), "crypt_session", time.Hour) + require.NoError(t, err) + + token, err := manager.createAt("admin", time.Unix(1000, 0)) + require.NoError(t, err) + + last := token[len(token)-1] + replace := byte('a') + if last == replace { + replace = 'b' + } + tampered := token[:len(token)-1] + string(replace) + _, ok := manager.validateAt(tampered, time.Unix(1000, 0)) + require.False(t, ok) +} diff --git a/internal/app/settings.go b/internal/app/settings.go new file mode 100644 index 0000000..c047cc3 --- /dev/null +++ b/internal/app/settings.go @@ -0,0 +1,12 @@ +package app + +import "time" + +type Settings struct { + ApproveOwn bool + AllApprove bool + SessionTTL time.Duration + CookieSecure bool + RequestCleanupInterval time.Duration + RotateViewedSecrets bool +} diff --git a/internal/app/static_test.go b/internal/app/static_test.go new file mode 100644 index 0000000..4d83dd8 --- /dev/null +++ b/internal/app/static_test.go @@ -0,0 +1,99 @@ +package app + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestStaticPathsNotRedirected(t *testing.T) { + // Test the middleware directly to ensure static paths don't get redirected + called := false + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + }) + + wrapped := withTrailingSlashRedirect(handler) + + tests := []struct { + name string + path string + shouldRedirect bool + expectedLocation string + }{ + { + name: "Static CSS not redirected", + path: "/static/style.css", + shouldRedirect: false, + }, + { + name: "Static nested path not redirected", + path: "/static/bootstrap/css/bootstrap.min.css", + shouldRedirect: false, + }, + { + name: "Static JS not redirected", + path: "/static/js/app.js", + shouldRedirect: false, + }, + { + name: "Non-static path redirected", + path: "/admin/users", + shouldRedirect: true, + expectedLocation: "/admin/users/", + }, + { + name: "Login path redirected", + path: "/login", + shouldRedirect: true, + expectedLocation: "/login/", + }, + { + name: "Root path not redirected", + path: "/", + shouldRedirect: false, + }, + { + name: "Path with trailing slash not redirected", + path: "/admin/users/", + shouldRedirect: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + called = false + req := httptest.NewRequest("GET", tt.path, nil) + rec := httptest.NewRecorder() + + wrapped.ServeHTTP(rec, req) + + if tt.shouldRedirect { + require.Equal(t, http.StatusMovedPermanently, rec.Code, "Expected redirect for %s", tt.path) + require.Equal(t, tt.expectedLocation, rec.Header().Get("Location"), "Expected redirect location") + require.False(t, called, "Handler should not be called on redirect") + } else { + require.Equal(t, http.StatusOK, rec.Code, "Expected no redirect for %s", tt.path) + require.Empty(t, rec.Header().Get("Location"), "Should not have Location header") + require.True(t, called, "Handler should be called when not redirecting") + } + }) + } +} + +func TestTrailingSlashRedirectWorksForNonStatic(t *testing.T) { + server, _, _ := newTestServer(t) + handler := server.Routes() + + // Test that non-static paths still get redirected + req := httptest.NewRequest("GET", "/admin/users", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + require.Equal(t, http.StatusMovedPermanently, rec.Code, "Non-static paths should redirect to add trailing slash") + require.Equal(t, "/admin/users/", rec.Header().Get("Location"), "Should redirect to path with trailing slash") +} diff --git a/internal/app/version_test.go b/internal/app/version_test.go new file mode 100644 index 0000000..174bf23 --- /dev/null +++ b/internal/app/version_test.go @@ -0,0 +1,61 @@ +package app + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestVersionDisplayedOnAllPages(t *testing.T) { + server, _, sessionManager := newTestServer(t) + + // Set a test version + originalVersion := Version + Version = "test-version-1.2.3" + defer func() { Version = originalVersion }() + + testCases := []struct { + name string + path string + requireAuth bool + }{ + {"Login page", "/login/", false}, + {"Index page", "/", true}, + {"User list", "/admin/users/", true}, + {"New user", "/admin/users/new/", true}, + {"Audit log", "/admin/audit/", true}, + {"Password change", "/password/change/", true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rec := httptest.NewRecorder() + + var req *http.Request + if tc.requireAuth { + req = newAuthenticatedRequest(t, sessionManager, http.MethodGet, tc.path, nil, "admin") + serveProtected(server, rec, req, func(w http.ResponseWriter, r *http.Request) { + // Let the actual handler run through Routes() + server.Routes().ServeHTTP(w, r) + }) + } else { + req = httptest.NewRequest(http.MethodGet, tc.path, nil) + server.Routes().ServeHTTP(rec, req) + } + + require.Contains(t, rec.Body.String(), "Crypt Server version test-version-1.2.3", + "Version should be displayed on %s", tc.name) + }) + } +} + +func TestVersionVariable(t *testing.T) { + // Test that version can be set and retrieved + originalVersion := Version + defer func() { Version = originalVersion }() + + Version = "1.0.0" + require.Equal(t, "1.0.0", Version) +} diff --git a/internal/app/viewmodels.go b/internal/app/viewmodels.go new file mode 100644 index 0000000..287bb63 --- /dev/null +++ b/internal/app/viewmodels.go @@ -0,0 +1,77 @@ +package app + +import "crypt-server/internal/store" + +type User struct { + ID int + Username string + IsAuthenticated bool + IsStaff bool + CanApprove bool + LocalLoginEnabled bool + MustResetPassword bool + AuthSource string +} + +type SecretView struct { + Secret *store.Secret + Approved bool + Pending bool +} + +type SecretChar struct { + Char string + Class string +} + +type RequestView struct { + ID int + Serial string + ComputerName string + RequestingUser string + ReasonForRequest string + DateRequested string +} + +type TemplateData struct { + Title string + User User + Version string + OutstandingCount int + Computers []*store.Computer + Computer *store.Computer + Secrets []*store.Secret + SecretViews []SecretView + Secret *store.Secret + Requests []*store.Request + ManageRequests []RequestView + Request *store.Request + ErrorMessage string + CanRequest bool + RequestApproved bool + ApprovedRequestID int + RequestsForSecret []*store.Request + SecretChars []SecretChar + Users []*store.User + NewUser UserForm + AdminUser *store.User + AuditEvents []*store.AuditEvent + AuditSearch string + AuditPage int + AuditPageSize int + AuditTotal int + AuditTotalPages int + CSRFToken string + PasswordChangeRequiresCurrent bool + SAMLAvailable bool + SAMLLoginURL string +} + +type UserForm struct { + Username string + IsStaff bool + CanApprove bool + LocalLoginEnabled bool + MustResetPassword bool + AuthSource string +} diff --git a/internal/crypto/secret.go b/internal/crypto/secret.go new file mode 100644 index 0000000..1cd64f2 --- /dev/null +++ b/internal/crypto/secret.go @@ -0,0 +1,69 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "io" +) + +type AesGcmCodec struct { + aead cipher.AEAD +} + +func NewAesGcmCodecFromBase64Key(keyBase64 string) (*AesGcmCodec, error) { + if keyBase64 == "" { + return nil, errors.New("missing encryption key") + } + key, err := base64.StdEncoding.DecodeString(keyBase64) + if err != nil { + return nil, fmt.Errorf("decode key: %w", err) + } + return NewAesGcmCodec(key) +} + +func NewAesGcmCodec(key []byte) (*AesGcmCodec, error) { + if len(key) != 32 { + return nil, fmt.Errorf("invalid key length: %d", len(key)) + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, fmt.Errorf("new cipher: %w", err) + } + aead, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("new gcm: %w", err) + } + return &AesGcmCodec{aead: aead}, nil +} + +func (c *AesGcmCodec) Encrypt(plaintext string) (string, error) { + nonce := make([]byte, c.aead.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", fmt.Errorf("read nonce: %w", err) + } + ciphertext := c.aead.Seal(nil, nonce, []byte(plaintext), nil) + payload := append(nonce, ciphertext...) + return base64.StdEncoding.EncodeToString(payload), nil +} + +func (c *AesGcmCodec) Decrypt(ciphertext string) (string, error) { + payload, err := base64.StdEncoding.DecodeString(ciphertext) + if err != nil { + return "", fmt.Errorf("decode ciphertext: %w", err) + } + nonceSize := c.aead.NonceSize() + if len(payload) < nonceSize { + return "", errors.New("ciphertext too short") + } + nonce := payload[:nonceSize] + sealed := payload[nonceSize:] + plaintext, err := c.aead.Open(nil, nonce, sealed, nil) + if err != nil { + return "", fmt.Errorf("decrypt: %w", err) + } + return string(plaintext), nil +} diff --git a/internal/crypto/secret_test.go b/internal/crypto/secret_test.go new file mode 100644 index 0000000..1c1620e --- /dev/null +++ b/internal/crypto/secret_test.go @@ -0,0 +1,189 @@ +package crypto + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewAesGcmCodecFromBase64Key(t *testing.T) { + // Valid 32-byte key + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + encoded := base64.StdEncoding.EncodeToString(key) + + codec, err := NewAesGcmCodecFromBase64Key(encoded) + require.NoError(t, err) + require.NotNil(t, codec) +} + +func TestNewAesGcmCodecFromBase64KeyEmpty(t *testing.T) { + _, err := NewAesGcmCodecFromBase64Key("") + require.Error(t, err) + require.Contains(t, err.Error(), "missing encryption key") +} + +func TestNewAesGcmCodecFromBase64KeyInvalidBase64(t *testing.T) { + _, err := NewAesGcmCodecFromBase64Key("not-valid-base64!!!") + require.Error(t, err) + require.Contains(t, err.Error(), "decode key") +} + +func TestNewAesGcmCodecFromBase64KeyWrongLength(t *testing.T) { + // 16-byte key (too short) + key := make([]byte, 16) + encoded := base64.StdEncoding.EncodeToString(key) + + _, err := NewAesGcmCodecFromBase64Key(encoded) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid key length") +} + +func TestNewAesGcmCodec(t *testing.T) { + key := make([]byte, 32) + codec, err := NewAesGcmCodec(key) + require.NoError(t, err) + require.NotNil(t, codec) +} + +func TestNewAesGcmCodecWrongLength(t *testing.T) { + tests := []struct { + name string + keyLen int + }{ + {"too short", 16}, + {"too long", 64}, + {"empty", 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key := make([]byte, tt.keyLen) + _, err := NewAesGcmCodec(key) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid key length") + }) + } +} + +func TestEncryptDecryptRoundTrip(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i) + } + codec, err := NewAesGcmCodec(key) + require.NoError(t, err) + + tests := []string{ + "hello world", + "", + "a", + "recovery-key-12345-ABCDE", + "unicode: 日本語 emoji: 🔐", + string(make([]byte, 1000)), // long string + } + + for _, plaintext := range tests { + t.Run(plaintext[:min(len(plaintext), 20)], func(t *testing.T) { + ciphertext, err := codec.Encrypt(plaintext) + require.NoError(t, err) + require.NotEmpty(t, ciphertext) + require.NotEqual(t, plaintext, ciphertext) + + decrypted, err := codec.Decrypt(ciphertext) + require.NoError(t, err) + require.Equal(t, plaintext, decrypted) + }) + } +} + +func TestEncryptProducesDifferentCiphertext(t *testing.T) { + key := make([]byte, 32) + codec, err := NewAesGcmCodec(key) + require.NoError(t, err) + + plaintext := "same plaintext" + ciphertext1, err := codec.Encrypt(plaintext) + require.NoError(t, err) + + ciphertext2, err := codec.Encrypt(plaintext) + require.NoError(t, err) + + // Due to random nonce, same plaintext should produce different ciphertext + require.NotEqual(t, ciphertext1, ciphertext2) + + // But both should decrypt to the same value + decrypted1, err := codec.Decrypt(ciphertext1) + require.NoError(t, err) + decrypted2, err := codec.Decrypt(ciphertext2) + require.NoError(t, err) + require.Equal(t, decrypted1, decrypted2) +} + +func TestDecryptInvalidBase64(t *testing.T) { + key := make([]byte, 32) + codec, err := NewAesGcmCodec(key) + require.NoError(t, err) + + _, err = codec.Decrypt("not-valid-base64!!!") + require.Error(t, err) + require.Contains(t, err.Error(), "decode ciphertext") +} + +func TestDecryptTooShort(t *testing.T) { + key := make([]byte, 32) + codec, err := NewAesGcmCodec(key) + require.NoError(t, err) + + // Base64 encode a very short payload (less than nonce size) + short := base64.StdEncoding.EncodeToString([]byte("abc")) + _, err = codec.Decrypt(short) + require.Error(t, err) + require.Contains(t, err.Error(), "ciphertext too short") +} + +func TestDecryptTampered(t *testing.T) { + key := make([]byte, 32) + codec, err := NewAesGcmCodec(key) + require.NoError(t, err) + + ciphertext, err := codec.Encrypt("secret data") + require.NoError(t, err) + + // Decode, tamper, re-encode + payload, err := base64.StdEncoding.DecodeString(ciphertext) + require.NoError(t, err) + payload[len(payload)-1] ^= 0xFF // flip bits in last byte + tampered := base64.StdEncoding.EncodeToString(payload) + + _, err = codec.Decrypt(tampered) + require.Error(t, err) + require.Contains(t, err.Error(), "decrypt") +} + +func TestDecryptWrongKey(t *testing.T) { + key1 := make([]byte, 32) + key2 := make([]byte, 32) + key2[0] = 1 // different key + + codec1, err := NewAesGcmCodec(key1) + require.NoError(t, err) + codec2, err := NewAesGcmCodec(key2) + require.NoError(t, err) + + ciphertext, err := codec1.Encrypt("secret") + require.NoError(t, err) + + _, err = codec2.Decrypt(ciphertext) + require.Error(t, err) +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/internal/fixture/types.go b/internal/fixture/types.go new file mode 100644 index 0000000..4348e4b --- /dev/null +++ b/internal/fixture/types.go @@ -0,0 +1,58 @@ +package fixture + +// MigrationOutput is the structure of the converted fixture JSON. +type MigrationOutput struct { + Computers []Computer `json:"computers"` + Secrets []Secret `json:"secrets"` + Requests []Request `json:"requests"` + Users []User `json:"users"` +} + +// Computer represents a computer entry from the fixture. +type Computer struct { + ID int `json:"id"` + Serial string `json:"serial"` + Username string `json:"username"` + ComputerName string `json:"computername"` + LastCheckin string `json:"last_checkin"` +} + +// Secret represents a secret entry from the fixture. +// The Secret field is already encrypted with the new key. +type Secret struct { + ID int `json:"id"` + ComputerID int `json:"computer_id"` + SecretType string `json:"secret_type"` + Secret string `json:"secret"` + DateEscrowed string `json:"date_escrowed"` + RotationRequired bool `json:"rotation_required"` +} + +// Request represents a request entry from the fixture. +type Request struct { + ID int `json:"id"` + SecretID int `json:"secret_id"` + RequestingUser string `json:"requesting_user"` + Approved *bool `json:"approved"` + AuthUser string `json:"auth_user"` + ReasonForRequest string `json:"reason_for_request"` + ReasonForApproval string `json:"reason_for_approval"` + DateRequested string `json:"date_requested"` + DateApproved string `json:"date_approved"` + Current bool `json:"current"` +} + +// User represents a user entry from the fixture. +type User struct { + ID int `json:"id"` + Username string `json:"username"` + Email string `json:"email"` + IsStaff bool `json:"is_staff"` + IsSuper bool `json:"is_superuser"` + CanApprove bool `json:"can_approve"` + Groups []string `json:"groups"` + PasswordHash string `json:"password_hash"` + MustResetPassword bool `json:"must_reset_password"` + LocalLoginEnabled bool `json:"local_login_enabled"` + AuthSource string `json:"auth_source"` +} diff --git a/internal/migrate/embedded.go b/internal/migrate/embedded.go new file mode 100644 index 0000000..8376d8c --- /dev/null +++ b/internal/migrate/embedded.go @@ -0,0 +1,6 @@ +package migrate + +import "embed" + +//go:embed migrations/*/*.sql +var EmbeddedFS embed.FS diff --git a/internal/migrate/migrate.go b/internal/migrate/migrate.go new file mode 100644 index 0000000..6fd1ce5 --- /dev/null +++ b/internal/migrate/migrate.go @@ -0,0 +1,230 @@ +package migrate + +import ( + "database/sql" + "fmt" + "io/fs" + "path/filepath" + "regexp" + "sort" + "strconv" + "strings" +) + +type Migration struct { + Version int + Name string + SQL string +} + +var migrationPattern = regexp.MustCompile(`^(\d+)_.*\.sql$`) + +func Apply(db *sql.DB, driver string, fsys fs.FS) error { + if err := ensureSchemaMigrations(db, driver); err != nil { + return err + } + migrations, err := loadMigrations(fsys) + if err != nil { + return err + } + applied, err := loadAppliedVersions(db) + if err != nil { + return err + } + for _, migration := range migrations { + if applied[migration.Version] { + continue + } + if err := applyMigration(db, driver, migration); err != nil { + return err + } + } + return nil +} + +func List(fsys fs.FS) ([]Migration, error) { + return loadMigrations(fsys) +} + +func Validate(fsys fs.FS) error { + migrations, err := loadMigrations(fsys) + if err != nil { + return err + } + if len(migrations) == 0 { + return fmt.Errorf("no migrations found") + } + versions := make(map[int]struct{}) + for _, migration := range migrations { + if strings.TrimSpace(migration.SQL) == "" { + return fmt.Errorf("migration %s is empty", migration.Name) + } + if _, exists := versions[migration.Version]; exists { + return fmt.Errorf("duplicate migration version %d", migration.Version) + } + versions[migration.Version] = struct{}{} + } + return nil +} + +func ensureSchemaMigrations(db *sql.DB, driver string) error { + stmt, err := schemaMigrationsSQL(driver) + if err != nil { + return err + } + if _, err := db.Exec(stmt); err != nil { + return fmt.Errorf("create schema_migrations: %w", err) + } + return nil +} + +func schemaMigrationsSQL(driver string) (string, error) { + switch driver { + case "postgres": + return "CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY, applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW())", nil + case "sqlite": + return "CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY, applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP)", nil + default: + return "", fmt.Errorf("unsupported database driver: %s", driver) + } +} + +func loadMigrations(fsys fs.FS) ([]Migration, error) { + entries, err := fs.ReadDir(fsys, ".") + if err != nil { + return nil, fmt.Errorf("read migrations: %w", err) + } + migrations := make([]Migration, 0) + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + matches := migrationPattern.FindStringSubmatch(name) + if matches == nil { + continue + } + version, err := strconv.Atoi(matches[1]) + if err != nil { + return nil, fmt.Errorf("parse migration version %s: %w", name, err) + } + data, err := fs.ReadFile(fsys, name) + if err != nil { + return nil, fmt.Errorf("read migration %s: %w", name, err) + } + migrations = append(migrations, Migration{ + Version: version, + Name: name, + SQL: string(data), + }) + } + sort.Slice(migrations, func(i, j int) bool { + if migrations[i].Version == migrations[j].Version { + return migrations[i].Name < migrations[j].Name + } + return migrations[i].Version < migrations[j].Version + }) + return migrations, nil +} + +func loadAppliedVersions(db *sql.DB) (map[int]bool, error) { + rows, err := db.Query("SELECT version FROM schema_migrations ORDER BY version") + if err != nil { + return nil, fmt.Errorf("load applied migrations: %w", err) + } + defer rows.Close() + + applied := make(map[int]bool) + for rows.Next() { + var version int + if err := rows.Scan(&version); err != nil { + return nil, fmt.Errorf("scan applied migration: %w", err) + } + applied[version] = true + } + return applied, rows.Err() +} + +func applyMigration(db *sql.DB, driver string, migration Migration) error { + statements := splitStatements(migration.SQL) + if len(statements) == 0 { + return fmt.Errorf("migration %s is empty", migration.Name) + } + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("begin migration %s: %w", migration.Name, err) + } + for _, statement := range statements { + if _, err := tx.Exec(statement); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf("rollback migration %s: %v", migration.Name, rollbackErr) + } + return fmt.Errorf("apply migration %s: %w", migration.Name, err) + } + } + insertSQL, err := insertMigrationSQL(driver) + if err != nil { + _ = tx.Rollback() + return err + } + if _, err := tx.Exec(insertSQL, migration.Version); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf("rollback migration %s: %v", migration.Name, rollbackErr) + } + return fmt.Errorf("record migration %s: %w", migration.Name, err) + } + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit migration %s: %w", migration.Name, err) + } + return nil +} + +func insertMigrationSQL(driver string) (string, error) { + switch driver { + case "postgres": + return "INSERT INTO schema_migrations (version) VALUES ($1)", nil + case "sqlite": + return "INSERT INTO schema_migrations (version) VALUES (?)", nil + default: + return "", fmt.Errorf("unsupported database driver: %s", driver) + } +} + +func splitStatements(sqlText string) []string { + trimmed := strings.TrimSpace(sqlText) + if trimmed == "" { + return nil + } + statements := make([]string, 0) + var current strings.Builder + inSingle := false + inDouble := false + var prev rune + for _, ch := range sqlText { + if ch == '\'' && !inDouble && prev != '\\' { + inSingle = !inSingle + } else if ch == '"' && !inSingle && prev != '\\' { + inDouble = !inDouble + } + if ch == ';' && !inSingle && !inDouble { + statement := strings.TrimSpace(current.String()) + if statement != "" { + statements = append(statements, statement) + } + current.Reset() + prev = ch + continue + } + current.WriteRune(ch) + prev = ch + } + statement := strings.TrimSpace(current.String()) + if statement != "" { + statements = append(statements, statement) + } + return statements +} + +func SubMigrationsFS(fsys fs.FS, driver string) (fs.FS, error) { + return fs.Sub(fsys, filepath.Join("migrations", driver)) +} diff --git a/internal/migrate/migrate_test.go b/internal/migrate/migrate_test.go new file mode 100644 index 0000000..b953316 --- /dev/null +++ b/internal/migrate/migrate_test.go @@ -0,0 +1,122 @@ +package migrate + +import ( + "regexp" + "testing" + "testing/fstest" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" +) + +func TestSplitStatements(t *testing.T) { + input := "CREATE TABLE a (name TEXT);INSERT INTO a (name) VALUES ('x;y');" + statements := splitStatements(input) + require.Len(t, statements, 2) + require.Equal(t, "CREATE TABLE a (name TEXT)", statements[0]) + require.Equal(t, "INSERT INTO a (name) VALUES ('x;y')", statements[1]) +} + +func TestLoadMigrationsOrdersByVersion(t *testing.T) { + fs := fstest.MapFS{ + "002_add.sql": {Data: []byte("CREATE TABLE b (id INTEGER);")}, + "001_init.sql": {Data: []byte("CREATE TABLE a (id INTEGER);")}, + "README.txt": {Data: []byte("skip")}, + } + migrations, err := loadMigrations(fs) + require.NoError(t, err) + require.Len(t, migrations, 2) + require.Equal(t, 1, migrations[0].Version) + require.Equal(t, 2, migrations[1].Version) +} + +func TestSubMigrationsFS(t *testing.T) { + fs := fstest.MapFS{ + "migrations/postgres/001_init.sql": {Data: []byte("CREATE TABLE a (id INTEGER);")}, + } + sub, err := SubMigrationsFS(fs, "postgres") + require.NoError(t, err) + + migrations, err := loadMigrations(sub) + require.NoError(t, err) + require.Len(t, migrations, 1) +} + +func TestValidateRejectsEmpty(t *testing.T) { + fs := fstest.MapFS{ + "001_init.sql": {Data: []byte(" ")}, + } + err := Validate(fs) + require.Error(t, err) +} + +func TestValidateRejectsDuplicateVersions(t *testing.T) { + fs := fstest.MapFS{ + "001_init.sql": {Data: []byte("CREATE TABLE a (id INTEGER);")}, + "001_duplicate.sql": {Data: []byte("CREATE TABLE b (id INTEGER);")}, + } + err := Validate(fs) + require.Error(t, err) +} + +func TestValidateAcceptsMigrations(t *testing.T) { + fs := fstest.MapFS{ + "001_init.sql": {Data: []byte("CREATE TABLE a (id INTEGER);")}, + "002_next.sql": {Data: []byte("CREATE TABLE b (id INTEGER);")}, + } + err := Validate(fs) + require.NoError(t, err) +} + +func TestApplySkipsAppliedMigrations(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + mock.ExpectExec(regexp.QuoteMeta( + "CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY, applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW())", + )).WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT version FROM schema_migrations ORDER BY version", + )).WillReturnRows(sqlmock.NewRows([]string{"version"}).AddRow(1)) + + fs := fstest.MapFS{ + "001_init.sql": {Data: []byte("CREATE TABLE a (id INTEGER)")}, + "002_add.sql": {Data: []byte("CREATE TABLE b (id INTEGER)")}, + } + + mock.ExpectBegin() + mock.ExpectExec(regexp.QuoteMeta("CREATE TABLE b (id INTEGER)")).WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(regexp.QuoteMeta("INSERT INTO schema_migrations (version) VALUES ($1)")).WithArgs(2).WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + err = Apply(db, "postgres", fs) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyUsesSQLitePlaceholders(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + mock.ExpectExec(regexp.QuoteMeta( + "CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY, applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP)", + )).WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT version FROM schema_migrations ORDER BY version", + )).WillReturnRows(sqlmock.NewRows([]string{"version"})) + + fs := fstest.MapFS{ + "001_init.sql": {Data: []byte("CREATE TABLE a (id INTEGER)")}, + } + + mock.ExpectBegin() + mock.ExpectExec(regexp.QuoteMeta("CREATE TABLE a (id INTEGER)")).WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(regexp.QuoteMeta("INSERT INTO schema_migrations (version) VALUES (?)")).WithArgs(1).WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + err = Apply(db, "sqlite", fs) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/internal/migrate/migrations/postgres/001_init.sql b/internal/migrate/migrations/postgres/001_init.sql new file mode 100644 index 0000000..8b217de --- /dev/null +++ b/internal/migrate/migrations/postgres/001_init.sql @@ -0,0 +1,40 @@ +CREATE TABLE IF NOT EXISTS computers ( + id SERIAL PRIMARY KEY, + serial TEXT NOT NULL UNIQUE, + username TEXT NOT NULL, + computername TEXT NOT NULL, + last_checkin TIMESTAMPTZ +); + +CREATE TABLE IF NOT EXISTS secrets ( + id SERIAL PRIMARY KEY, + computer_id INTEGER NOT NULL REFERENCES computers(id) ON DELETE CASCADE, + secret TEXT NOT NULL, + secret_type TEXT NOT NULL, + date_escrowed TIMESTAMPTZ NOT NULL DEFAULT NOW(), + rotation_required BOOLEAN NOT NULL DEFAULT FALSE +); + +CREATE TABLE IF NOT EXISTS requests ( + id SERIAL PRIMARY KEY, + secret_id INTEGER NOT NULL REFERENCES secrets(id) ON DELETE RESTRICT, + requesting_user TEXT NOT NULL, + approved BOOLEAN NULL, + auth_user TEXT NULL, + reason_for_request TEXT NOT NULL, + reason_for_approval TEXT NULL, + date_requested TIMESTAMPTZ NOT NULL DEFAULT NOW(), + date_approved TIMESTAMPTZ NULL, + current BOOLEAN NOT NULL DEFAULT TRUE +); + +CREATE TABLE IF NOT EXISTS users ( + id SERIAL PRIMARY KEY, + username TEXT NOT NULL UNIQUE, + password_hash TEXT NULL, + is_staff BOOLEAN NOT NULL DEFAULT FALSE, + can_approve BOOLEAN NOT NULL DEFAULT FALSE, + local_login_enabled BOOLEAN NOT NULL DEFAULT FALSE, + must_reset_password BOOLEAN NOT NULL DEFAULT FALSE, + auth_source TEXT NOT NULL DEFAULT 'local' +); diff --git a/internal/migrate/migrations/postgres/002_audit_events.sql b/internal/migrate/migrations/postgres/002_audit_events.sql new file mode 100644 index 0000000..37426d3 --- /dev/null +++ b/internal/migrate/migrations/postgres/002_audit_events.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS audit_events ( + id SERIAL PRIMARY KEY, + actor TEXT NOT NULL, + target_user TEXT NOT NULL, + action TEXT NOT NULL, + reason TEXT NULL, + ip_address TEXT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); diff --git a/internal/migrate/migrations/sqlite/001_init.sql b/internal/migrate/migrations/sqlite/001_init.sql new file mode 100644 index 0000000..02a044e --- /dev/null +++ b/internal/migrate/migrations/sqlite/001_init.sql @@ -0,0 +1,40 @@ +CREATE TABLE IF NOT EXISTS computers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + serial TEXT NOT NULL UNIQUE, + username TEXT NOT NULL, + computername TEXT NOT NULL, + last_checkin TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS secrets ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + computer_id INTEGER NOT NULL REFERENCES computers(id) ON DELETE CASCADE, + secret TEXT NOT NULL, + secret_type TEXT NOT NULL, + date_escrowed TIMESTAMP NOT NULL, + rotation_required BOOLEAN NOT NULL DEFAULT 0 +); + +CREATE TABLE IF NOT EXISTS requests ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + secret_id INTEGER NOT NULL REFERENCES secrets(id) ON DELETE RESTRICT, + requesting_user TEXT NOT NULL, + approved BOOLEAN NULL, + auth_user TEXT NULL, + reason_for_request TEXT NOT NULL, + reason_for_approval TEXT NULL, + date_requested TIMESTAMP NOT NULL, + date_approved TIMESTAMP NULL, + current BOOLEAN NOT NULL DEFAULT 1 +); + +CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT NOT NULL UNIQUE, + password_hash TEXT NULL, + is_staff BOOLEAN NOT NULL DEFAULT 0, + can_approve BOOLEAN NOT NULL DEFAULT 0, + local_login_enabled BOOLEAN NOT NULL DEFAULT 0, + must_reset_password BOOLEAN NOT NULL DEFAULT 0, + auth_source TEXT NOT NULL DEFAULT 'local' +); diff --git a/internal/migrate/migrations/sqlite/002_audit_events.sql b/internal/migrate/migrations/sqlite/002_audit_events.sql new file mode 100644 index 0000000..2a94842 --- /dev/null +++ b/internal/migrate/migrations/sqlite/002_audit_events.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS audit_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + actor TEXT NOT NULL, + target_user TEXT NOT NULL, + action TEXT NOT NULL, + reason TEXT NULL, + ip_address TEXT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); diff --git a/internal/store/crypto_test.go b/internal/store/crypto_test.go new file mode 100644 index 0000000..3a90f16 --- /dev/null +++ b/internal/store/crypto_test.go @@ -0,0 +1,28 @@ +package store + +import ( + "encoding/base64" + "testing" + + "crypt-server/internal/crypto" + "github.com/stretchr/testify/require" +) + +func TestAesGcmCodecRoundTrip(t *testing.T) { + key := make([]byte, 32) + for i := range key { + key[i] = byte(i + 1) + } + encoded := base64.StdEncoding.EncodeToString(key) + codec, err := crypto.NewAesGcmCodecFromBase64Key(encoded) + require.NoError(t, err) + + ciphertext, err := codec.Encrypt("secret") + require.NoError(t, err) + require.NotEmpty(t, ciphertext) + require.NotEqual(t, "secret", ciphertext) + + plaintext, err := codec.Decrypt(ciphertext) + require.NoError(t, err) + require.Equal(t, "secret", plaintext) +} diff --git a/internal/store/logging.go b/internal/store/logging.go new file mode 100644 index 0000000..b8c06ad --- /dev/null +++ b/internal/store/logging.go @@ -0,0 +1,255 @@ +package store + +import ( + "log" + "time" +) + +// LoggingStore wraps a Store and logs all operations. +type LoggingStore struct { + store Store + logger *log.Logger +} + +// NewLoggingStore creates a new LoggingStore that wraps the given store. +func NewLoggingStore(store Store, logger *log.Logger) *LoggingStore { + return &LoggingStore{store: store, logger: logger} +} + +func (s *LoggingStore) AddComputer(serial, username, computerName string) (*Computer, error) { + result, err := s.store.AddComputer(serial, username, computerName) + if err != nil { + s.logger.Printf("db: AddComputer failed: serial=%s error=%v", serial, err) + } else { + s.logger.Printf("db: AddComputer: serial=%s username=%s computername=%s", serial, username, computerName) + } + return result, err +} + +func (s *LoggingStore) UpsertComputer(serial, username, computerName string, lastCheckin time.Time) (*Computer, error) { + result, err := s.store.UpsertComputer(serial, username, computerName, lastCheckin) + if err != nil { + s.logger.Printf("db: UpsertComputer failed: serial=%s error=%v", serial, err) + } else { + s.logger.Printf("db: UpsertComputer: serial=%s username=%s computername=%s", serial, username, computerName) + } + return result, err +} + +func (s *LoggingStore) ListComputers() ([]*Computer, error) { + return s.store.ListComputers() +} + +func (s *LoggingStore) GetComputerByID(id int) (*Computer, error) { + return s.store.GetComputerByID(id) +} + +func (s *LoggingStore) GetComputerBySerial(serial string) (*Computer, error) { + return s.store.GetComputerBySerial(serial) +} + +func (s *LoggingStore) AddSecret(computerID int, secretType, secret string, rotationRequired bool) (*Secret, bool, error) { + result, isNew, err := s.store.AddSecret(computerID, secretType, secret, rotationRequired) + if err != nil { + s.logger.Printf("db: AddSecret failed: computer_id=%d type=%s error=%v", computerID, secretType, err) + } else if isNew { + s.logger.Printf("db: AddSecret: computer_id=%d type=%s (new)", computerID, secretType) + } else { + s.logger.Printf("db: AddSecret: computer_id=%d type=%s (updated)", computerID, secretType) + } + return result, isNew, err +} + +func (s *LoggingStore) ListSecretsByComputer(computerID int) ([]*Secret, error) { + return s.store.ListSecretsByComputer(computerID) +} + +func (s *LoggingStore) GetSecretByID(id int) (*Secret, error) { + return s.store.GetSecretByID(id) +} + +func (s *LoggingStore) GetLatestSecretByComputerAndType(computerID int, secretType string) (*Secret, error) { + return s.store.GetLatestSecretByComputerAndType(computerID, secretType) +} + +func (s *LoggingStore) AddRequest(secretID int, requestingUser, reason string, approvedBy string, approved *bool) (*Request, error) { + result, err := s.store.AddRequest(secretID, requestingUser, reason, approvedBy, approved) + if err != nil { + s.logger.Printf("db: AddRequest failed: secret_id=%d user=%s error=%v", secretID, requestingUser, err) + } else { + s.logger.Printf("db: AddRequest: secret_id=%d user=%s", secretID, requestingUser) + } + return result, err +} + +func (s *LoggingStore) ListRequestsBySecret(secretID int) ([]*Request, error) { + return s.store.ListRequestsBySecret(secretID) +} + +func (s *LoggingStore) ListOutstandingRequests() ([]*Request, error) { + return s.store.ListOutstandingRequests() +} + +func (s *LoggingStore) GetRequestByID(id int) (*Request, error) { + return s.store.GetRequestByID(id) +} + +func (s *LoggingStore) ApproveRequest(requestID int, approved bool, reason, approver string) (*Request, error) { + result, err := s.store.ApproveRequest(requestID, approved, reason, approver) + if err != nil { + s.logger.Printf("db: ApproveRequest failed: request_id=%d error=%v", requestID, err) + } else { + s.logger.Printf("db: ApproveRequest: request_id=%d approved=%t approver=%s", requestID, approved, approver) + } + return result, err +} + +func (s *LoggingStore) AddUser(username, passwordHash string, isStaff, canApprove, localLoginEnabled, mustResetPassword bool, authSource string) (*User, error) { + result, err := s.store.AddUser(username, passwordHash, isStaff, canApprove, localLoginEnabled, mustResetPassword, authSource) + if err != nil { + s.logger.Printf("db: AddUser failed: username=%s error=%v", username, err) + } else { + s.logger.Printf("db: AddUser: username=%s is_staff=%t can_approve=%t auth_source=%s", username, isStaff, canApprove, authSource) + } + return result, err +} + +func (s *LoggingStore) GetUserByUsername(username string) (*User, error) { + return s.store.GetUserByUsername(username) +} + +func (s *LoggingStore) ListUsers() ([]*User, error) { + return s.store.ListUsers() +} + +func (s *LoggingStore) GetUserByID(id int) (*User, error) { + return s.store.GetUserByID(id) +} + +func (s *LoggingStore) UpdateUser(id int, username string, isStaff, canApprove, localLoginEnabled, mustResetPassword bool, authSource string) (*User, error) { + result, err := s.store.UpdateUser(id, username, isStaff, canApprove, localLoginEnabled, mustResetPassword, authSource) + if err != nil { + s.logger.Printf("db: UpdateUser failed: id=%d username=%s error=%v", id, username, err) + } else { + s.logger.Printf("db: UpdateUser: id=%d username=%s is_staff=%t can_approve=%t", id, username, isStaff, canApprove) + } + return result, err +} + +func (s *LoggingStore) UpdateUserPassword(id int, passwordHash string, mustResetPassword bool) (*User, error) { + result, err := s.store.UpdateUserPassword(id, passwordHash, mustResetPassword) + if err != nil { + s.logger.Printf("db: UpdateUserPassword failed: id=%d error=%v", id, err) + } else { + s.logger.Printf("db: UpdateUserPassword: id=%d username=%s", id, result.Username) + } + return result, err +} + +func (s *LoggingStore) DeleteUser(id int) error { + err := s.store.DeleteUser(id) + if err != nil { + s.logger.Printf("db: DeleteUser failed: id=%d error=%v", id, err) + } else { + s.logger.Printf("db: DeleteUser: id=%d", id) + } + return err +} + +func (s *LoggingStore) CleanupRequests(approvedBefore time.Time) (int, error) { + count, err := s.store.CleanupRequests(approvedBefore) + if err != nil { + s.logger.Printf("db: CleanupRequests failed: error=%v", err) + } else if count > 0 { + s.logger.Printf("db: CleanupRequests: cleaned=%d", count) + } + return count, err +} + +func (s *LoggingStore) SetSecretRotationRequired(secretID int, rotationRequired bool) (*Secret, error) { + result, err := s.store.SetSecretRotationRequired(secretID, rotationRequired) + if err != nil { + s.logger.Printf("db: SetSecretRotationRequired failed: secret_id=%d error=%v", secretID, err) + } else { + s.logger.Printf("db: SetSecretRotationRequired: secret_id=%d rotation_required=%t", secretID, rotationRequired) + } + return result, err +} + +func (s *LoggingStore) AddAuditEvent(actor, targetUser, action, reason, ipAddress string) (*AuditEvent, error) { + result, err := s.store.AddAuditEvent(actor, targetUser, action, reason, ipAddress) + if err != nil { + s.logger.Printf("db: AddAuditEvent failed: actor=%s action=%s error=%v", actor, action, err) + } else { + s.logger.Printf("db: AddAuditEvent: actor=%s target=%s action=%s", actor, targetUser, action) + } + return result, err +} + +func (s *LoggingStore) ListAuditEvents() ([]*AuditEvent, error) { + return s.store.ListAuditEvents() +} + +func (s *LoggingStore) SearchAuditEvents(query string) ([]*AuditEvent, error) { + return s.store.SearchAuditEvents(query) +} + +func (s *LoggingStore) ListAuditEventsPaged(limit, offset int) ([]*AuditEvent, error) { + return s.store.ListAuditEventsPaged(limit, offset) +} + +func (s *LoggingStore) SearchAuditEventsPaged(query string, limit, offset int) ([]*AuditEvent, error) { + return s.store.SearchAuditEventsPaged(query, limit, offset) +} + +func (s *LoggingStore) CountAuditEvents() (int, error) { + return s.store.CountAuditEvents() +} + +func (s *LoggingStore) CountSearchAuditEvents(query string) (int, error) { + return s.store.CountSearchAuditEvents(query) +} + +func (s *LoggingStore) IsEmpty() (bool, error) { + return s.store.IsEmpty() +} + +func (s *LoggingStore) ImportComputer(id int, serial, username, computerName string, lastCheckin time.Time) error { + err := s.store.ImportComputer(id, serial, username, computerName, lastCheckin) + if err != nil { + s.logger.Printf("db: ImportComputer failed: id=%d serial=%s error=%v", id, serial, err) + } else { + s.logger.Printf("db: ImportComputer: id=%d serial=%s", id, serial) + } + return err +} + +func (s *LoggingStore) ImportSecret(id, computerID int, secretType, encryptedSecret string, dateEscrowed time.Time, rotationRequired bool) error { + err := s.store.ImportSecret(id, computerID, secretType, encryptedSecret, dateEscrowed, rotationRequired) + if err != nil { + s.logger.Printf("db: ImportSecret failed: id=%d computer_id=%d error=%v", id, computerID, err) + } else { + s.logger.Printf("db: ImportSecret: id=%d computer_id=%d type=%s", id, computerID, secretType) + } + return err +} + +func (s *LoggingStore) ImportRequest(id, secretID int, requestingUser string, approved *bool, authUser, reasonForRequest, reasonForApproval string, dateRequested time.Time, dateApproved *time.Time, current bool) error { + err := s.store.ImportRequest(id, secretID, requestingUser, approved, authUser, reasonForRequest, reasonForApproval, dateRequested, dateApproved, current) + if err != nil { + s.logger.Printf("db: ImportRequest failed: id=%d secret_id=%d error=%v", id, secretID, err) + } else { + s.logger.Printf("db: ImportRequest: id=%d secret_id=%d user=%s", id, secretID, requestingUser) + } + return err +} + +func (s *LoggingStore) ImportUser(id int, username, passwordHash string, isStaff, canApprove, localLoginEnabled, mustResetPassword bool, authSource string) error { + err := s.store.ImportUser(id, username, passwordHash, isStaff, canApprove, localLoginEnabled, mustResetPassword, authSource) + if err != nil { + s.logger.Printf("db: ImportUser failed: id=%d username=%s error=%v", id, username, err) + } else { + s.logger.Printf("db: ImportUser: id=%d username=%s", id, username) + } + return err +} diff --git a/internal/store/logging_test.go b/internal/store/logging_test.go new file mode 100644 index 0000000..629d69b --- /dev/null +++ b/internal/store/logging_test.go @@ -0,0 +1,194 @@ +package store + +import ( + "bytes" + "log" + "path/filepath" + "testing" + "time" + + "crypt-server/internal/migrate" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" +) + +func newTestLoggingStore(t *testing.T) (*LoggingStore, *bytes.Buffer) { + t.Helper() + codec := testCodec(t) + path := filepath.Join(t.TempDir(), "crypt.db") + sqliteStore, err := NewSQLiteStore(path, codec) + require.NoError(t, err) + + // Apply real migrations + sqliteFS, err := migrate.SubMigrationsFS(migrate.EmbeddedFS, "sqlite") + require.NoError(t, err) + err = migrate.Apply(sqliteStore.DB(), "sqlite", sqliteFS) + require.NoError(t, err) + + var buf bytes.Buffer + logger := log.New(&buf, "", 0) + return NewLoggingStore(sqliteStore, logger), &buf +} + +func TestLoggingStoreLogsWrites(t *testing.T) { + loggingStore, buf := newTestLoggingStore(t) + + // AddComputer should log + _, err := loggingStore.AddComputer("ABC123", "testuser", "Test-Mac") + require.NoError(t, err) + require.Contains(t, buf.String(), "db: AddComputer") + require.Contains(t, buf.String(), "serial=ABC123") + buf.Reset() + + // UpsertComputer should log + _, err = loggingStore.UpsertComputer("DEF456", "testuser", "Test-Mac", time.Now()) + require.NoError(t, err) + require.Contains(t, buf.String(), "db: UpsertComputer") + buf.Reset() + + // AddUser should log + user, err := loggingStore.AddUser("admin", "hash", true, true, true, false, "local") + require.NoError(t, err) + require.Contains(t, buf.String(), "db: AddUser") + require.Contains(t, buf.String(), "username=admin") + buf.Reset() + + // UpdateUser should log + _, err = loggingStore.UpdateUser(user.ID, "admin", true, true, true, false, "local") + require.NoError(t, err) + require.Contains(t, buf.String(), "db: UpdateUser") + buf.Reset() + + // DeleteUser should log + err = loggingStore.DeleteUser(user.ID) + require.NoError(t, err) + require.Contains(t, buf.String(), "db: DeleteUser") + buf.Reset() +} + +func TestLoggingStoreDoesNotLogReads(t *testing.T) { + loggingStore, buf := newTestLoggingStore(t) + + // Add some data first + _, _ = loggingStore.AddComputer("ABC123", "testuser", "Test-Mac") + _, _ = loggingStore.AddUser("admin", "hash", true, true, true, false, "local") + buf.Reset() + + // ListComputers should not log + _, _ = loggingStore.ListComputers() + require.Empty(t, buf.String()) + + // GetComputerByID should not log + _, _ = loggingStore.GetComputerByID(1) + require.Empty(t, buf.String()) + + // GetComputerBySerial should not log + _, _ = loggingStore.GetComputerBySerial("ABC123") + require.Empty(t, buf.String()) + + // GetUserByUsername should not log + _, _ = loggingStore.GetUserByUsername("admin") + require.Empty(t, buf.String()) + + // ListUsers should not log + _, _ = loggingStore.ListUsers() + require.Empty(t, buf.String()) + + // GetUserByID should not log + _, _ = loggingStore.GetUserByID(1) + require.Empty(t, buf.String()) +} + +func TestLoggingStoreLogsSecretOperations(t *testing.T) { + loggingStore, buf := newTestLoggingStore(t) + + // Setup + computer, _ := loggingStore.AddComputer("ABC123", "testuser", "Test-Mac") + buf.Reset() + + // AddSecret (new) should log + secret, isNew, err := loggingStore.AddSecret(computer.ID, "recovery_key", "SECRET123", false) + require.NoError(t, err) + require.True(t, isNew) + require.Contains(t, buf.String(), "db: AddSecret") + require.Contains(t, buf.String(), "(new)") + buf.Reset() + + // AddSecret with same value should log as "(updated)" (duplicate detection) + _, isNew, err = loggingStore.AddSecret(computer.ID, "recovery_key", "SECRET123", false) + require.NoError(t, err) + require.False(t, isNew) + require.Contains(t, buf.String(), "db: AddSecret") + require.Contains(t, buf.String(), "(updated)") + buf.Reset() + + // SetSecretRotationRequired should log + _, err = loggingStore.SetSecretRotationRequired(secret.ID, true) + require.NoError(t, err) + require.Contains(t, buf.String(), "db: SetSecretRotationRequired") + buf.Reset() + + // GetSecretByID should not log + _, _ = loggingStore.GetSecretByID(secret.ID) + require.Empty(t, buf.String()) + + // ListSecretsByComputer should not log + _, _ = loggingStore.ListSecretsByComputer(computer.ID) + require.Empty(t, buf.String()) +} + +func TestLoggingStoreLogsRequestOperations(t *testing.T) { + loggingStore, buf := newTestLoggingStore(t) + + // Setup + computer, _ := loggingStore.AddComputer("ABC123", "testuser", "Test-Mac") + secret, _, _ := loggingStore.AddSecret(computer.ID, "recovery_key", "SECRET123", false) + buf.Reset() + + // AddRequest should log + request, err := loggingStore.AddRequest(secret.ID, "requester", "need key", "", nil) + require.NoError(t, err) + require.Contains(t, buf.String(), "db: AddRequest") + require.Contains(t, buf.String(), "user=requester") + buf.Reset() + + // ApproveRequest should log + _, err = loggingStore.ApproveRequest(request.ID, true, "approved", "approver") + require.NoError(t, err) + require.Contains(t, buf.String(), "db: ApproveRequest") + require.Contains(t, buf.String(), "approved=true") + buf.Reset() + + // GetRequestByID should not log + _, _ = loggingStore.GetRequestByID(request.ID) + require.Empty(t, buf.String()) + + // ListRequestsBySecret should not log + _, _ = loggingStore.ListRequestsBySecret(secret.ID) + require.Empty(t, buf.String()) + + // ListOutstandingRequests should not log + _, _ = loggingStore.ListOutstandingRequests() + require.Empty(t, buf.String()) +} + +func TestLoggingStoreLogsAuditEvents(t *testing.T) { + loggingStore, buf := newTestLoggingStore(t) + + // AddAuditEvent should log + _, err := loggingStore.AddAuditEvent("admin", "user1", "user_created", "", "127.0.0.1") + require.NoError(t, err) + require.Contains(t, buf.String(), "db: AddAuditEvent") + require.Contains(t, buf.String(), "actor=admin") + require.Contains(t, buf.String(), "action=user_created") + buf.Reset() + + // ListAuditEvents should not log + _, _ = loggingStore.ListAuditEvents() + require.Empty(t, buf.String()) + + // SearchAuditEvents should not log + _, _ = loggingStore.SearchAuditEvents("admin") + require.Empty(t, buf.String()) +} diff --git a/internal/store/models.go b/internal/store/models.go new file mode 100644 index 0000000..8f753a8 --- /dev/null +++ b/internal/store/models.go @@ -0,0 +1,99 @@ +package store + +import "time" + +// DateTimeFormat is the standard format for displaying dates (matches Django's Y-m-d H:i:s). +const DateTimeFormat = "2006-01-02 15:04:05" + +type Computer struct { + ID int + Serial string + Username string + ComputerName string + LastCheckin time.Time +} + +// LastCheckinFormatted returns the last checkin time in display format. +func (c Computer) LastCheckinFormatted() string { + return c.LastCheckin.Format(DateTimeFormat) +} + +type Secret struct { + ID int + ComputerID int + SecretType string + Secret string + DateEscrowed time.Time + RotationRequired bool +} + +// SecretTypeDisplay returns the human-readable display name for the secret type. +func (s Secret) SecretTypeDisplay() string { + switch s.SecretType { + case "recovery_key": + return "Recovery Key" + case "password": + return "Password" + case "unlock_pin": + return "Unlock PIN" + default: + return s.SecretType + } +} + +// DateEscrowedFormatted returns the escrow date in display format. +func (s Secret) DateEscrowedFormatted() string { + return s.DateEscrowed.Format(DateTimeFormat) +} + +type Request struct { + ID int + SecretID int + RequestingUser string + Approved *bool + AuthUser string + ReasonForRequest string + ReasonForApproval string + DateRequested time.Time + DateApproved *time.Time + Current bool +} + +// DateRequestedFormatted returns the request date in display format. +func (r Request) DateRequestedFormatted() string { + return r.DateRequested.Format(DateTimeFormat) +} + +// DateApprovedFormatted returns the approval date in display format, or empty string if not approved. +func (r Request) DateApprovedFormatted() string { + if r.DateApproved == nil { + return "" + } + return r.DateApproved.Format(DateTimeFormat) +} + +type User struct { + ID int + Username string + PasswordHash string + IsStaff bool + CanApprove bool + LocalLoginEnabled bool + MustResetPassword bool + AuthSource string +} + +type AuditEvent struct { + ID int + Actor string + TargetUser string + Action string + Reason string + IPAddress string + CreatedAt time.Time +} + +// CreatedAtFormatted returns the created time in display format. +func (a AuditEvent) CreatedAtFormatted() string { + return a.CreatedAt.Format(DateTimeFormat) +} diff --git a/internal/store/models_test.go b/internal/store/models_test.go new file mode 100644 index 0000000..2d94ca7 --- /dev/null +++ b/internal/store/models_test.go @@ -0,0 +1,74 @@ +package store + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestComputerLastCheckinFormatted(t *testing.T) { + c := Computer{ + LastCheckin: time.Date(2024, 6, 15, 14, 30, 45, 0, time.UTC), + } + require.Equal(t, "2024-06-15 14:30:45", c.LastCheckinFormatted()) +} + +func TestSecretSecretTypeDisplay(t *testing.T) { + tests := []struct { + secretType string + expected string + }{ + {"recovery_key", "Recovery Key"}, + {"password", "Password"}, + {"unlock_pin", "Unlock PIN"}, + {"custom_type", "custom_type"}, + {"", ""}, + } + + for _, tt := range tests { + t.Run(tt.secretType, func(t *testing.T) { + s := Secret{SecretType: tt.secretType} + require.Equal(t, tt.expected, s.SecretTypeDisplay()) + }) + } +} + +func TestSecretDateEscrowedFormatted(t *testing.T) { + s := Secret{ + DateEscrowed: time.Date(2024, 1, 20, 9, 15, 30, 0, time.UTC), + } + require.Equal(t, "2024-01-20 09:15:30", s.DateEscrowedFormatted()) +} + +func TestRequestDateRequestedFormatted(t *testing.T) { + r := Request{ + DateRequested: time.Date(2024, 3, 10, 16, 45, 0, 0, time.UTC), + } + require.Equal(t, "2024-03-10 16:45:00", r.DateRequestedFormatted()) +} + +func TestRequestDateApprovedFormatted(t *testing.T) { + t.Run("with date", func(t *testing.T) { + approvedTime := time.Date(2024, 3, 10, 17, 0, 0, 0, time.UTC) + r := Request{DateApproved: &approvedTime} + require.Equal(t, "2024-03-10 17:00:00", r.DateApprovedFormatted()) + }) + + t.Run("nil date", func(t *testing.T) { + r := Request{DateApproved: nil} + require.Equal(t, "", r.DateApprovedFormatted()) + }) +} + +func TestAuditEventCreatedAtFormatted(t *testing.T) { + a := AuditEvent{ + CreatedAt: time.Date(2024, 12, 25, 12, 0, 0, 0, time.UTC), + } + require.Equal(t, "2024-12-25 12:00:00", a.CreatedAtFormatted()) +} + +func TestDateTimeFormatConstant(t *testing.T) { + // Verify the format constant matches Django's Y-m-d H:i:s + require.Equal(t, "2006-01-02 15:04:05", DateTimeFormat) +} diff --git a/internal/store/postgres.go b/internal/store/postgres.go new file mode 100644 index 0000000..19e9dab --- /dev/null +++ b/internal/store/postgres.go @@ -0,0 +1,823 @@ +package store + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + _ "github.com/lib/pq" +) + +type PostgresStore struct { + db *sql.DB + codec SecretCodec +} + +func (s *PostgresStore) DB() *sql.DB { + return s.db +} + +func NewPostgresStore(dbURL string, codec SecretCodec) (*PostgresStore, error) { + db, err := sql.Open("postgres", dbURL) + if err != nil { + return nil, fmt.Errorf("open db: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := db.PingContext(ctx); err != nil { + return nil, fmt.Errorf("ping db: %w", err) + } + + return &PostgresStore{db: db, codec: codec}, nil +} + +func NewPostgresStoreWithDB(db *sql.DB, codec SecretCodec) *PostgresStore { + return &PostgresStore{db: db, codec: codec} +} + +func (s *PostgresStore) AddComputer(serial, username, computerName string) (*Computer, error) { + var id int + var lastCheckin time.Time + err := s.db.QueryRow( + `INSERT INTO computers (serial, username, computername, last_checkin) + VALUES ($1, $2, $3, NOW()) + RETURNING id, last_checkin`, + serial, username, computerName, + ).Scan(&id, &lastCheckin) + if err != nil { + return nil, fmt.Errorf("insert computer: %w", err) + } + return &Computer{ + ID: id, + Serial: serial, + Username: username, + ComputerName: computerName, + LastCheckin: lastCheckin, + }, nil +} + +func (s *PostgresStore) UpsertComputer(serial, username, computerName string, lastCheckin time.Time) (*Computer, error) { + var id int + var stored time.Time + err := s.db.QueryRow( + `INSERT INTO computers (serial, username, computername, last_checkin) + VALUES ($1, $2, $3, $4) + ON CONFLICT (serial) + DO UPDATE SET username = EXCLUDED.username, computername = EXCLUDED.computername, last_checkin = EXCLUDED.last_checkin + RETURNING id, last_checkin`, + serial, username, computerName, lastCheckin, + ).Scan(&id, &stored) + if err != nil { + return nil, fmt.Errorf("upsert computer: %w", err) + } + return &Computer{ + ID: id, + Serial: serial, + Username: username, + ComputerName: computerName, + LastCheckin: stored, + }, nil +} + +func (s *PostgresStore) ListComputers() ([]*Computer, error) { + rows, err := s.db.Query(`SELECT id, serial, username, computername, last_checkin FROM computers ORDER BY id`) + if err != nil { + return nil, fmt.Errorf("list computers: %w", err) + } + defer rows.Close() + + computers := make([]*Computer, 0) + for rows.Next() { + var computer Computer + var lastCheckin sql.NullTime + if err := rows.Scan(&computer.ID, &computer.Serial, &computer.Username, &computer.ComputerName, &lastCheckin); err != nil { + return nil, fmt.Errorf("scan computer: %w", err) + } + if lastCheckin.Valid { + computer.LastCheckin = lastCheckin.Time + } + computers = append(computers, &computer) + } + return computers, rows.Err() +} + +func (s *PostgresStore) GetComputerByID(id int) (*Computer, error) { + var computer Computer + var lastCheckin sql.NullTime + row := s.db.QueryRow(`SELECT id, serial, username, computername, last_checkin FROM computers WHERE id = $1`, id) + if err := row.Scan(&computer.ID, &computer.Serial, &computer.Username, &computer.ComputerName, &lastCheckin); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get computer by id: %w", err) + } + if lastCheckin.Valid { + computer.LastCheckin = lastCheckin.Time + } + return &computer, nil +} + +func (s *PostgresStore) GetComputerBySerial(serial string) (*Computer, error) { + var computer Computer + var lastCheckin sql.NullTime + row := s.db.QueryRow(`SELECT id, serial, username, computername, last_checkin FROM computers WHERE lower(serial) = lower($1)`, serial) + if err := row.Scan(&computer.ID, &computer.Serial, &computer.Username, &computer.ComputerName, &lastCheckin); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get computer by serial: %w", err) + } + if lastCheckin.Valid { + computer.LastCheckin = lastCheckin.Time + } + return &computer, nil +} + +func (s *PostgresStore) AddSecret(computerID int, secretType, secret string, rotationRequired bool) (*Secret, bool, error) { + if s.codec == nil { + return nil, false, ErrMissingCodec + } + + // Check for duplicate secret (matching Django behavior) + // If the same secret value already exists for this computer/type and rotation is not required, skip insert + if !rotationRequired { + existing, err := s.findExistingSecret(computerID, secretType, secret) + if err != nil { + return nil, false, err + } + if existing != nil { + return existing, false, nil + } + } + + encrypted, err := s.codec.Encrypt(secret) + if err != nil { + return nil, false, err + } + var id int + var dateEscrowed time.Time + row := s.db.QueryRow( + `INSERT INTO secrets (computer_id, secret_type, secret, date_escrowed, rotation_required) + VALUES ($1, $2, $3, NOW(), $4) + RETURNING id, date_escrowed`, + computerID, secretType, encrypted, rotationRequired, + ) + if err := row.Scan(&id, &dateEscrowed); err != nil { + return nil, false, fmt.Errorf("insert secret: %w", err) + } + + decrypted, err := s.decryptSecret(&Secret{ + ID: id, + ComputerID: computerID, + SecretType: secretType, + Secret: encrypted, + DateEscrowed: dateEscrowed, + RotationRequired: rotationRequired, + }) + if err != nil { + return nil, false, err + } + return decrypted, true, nil +} + +// findExistingSecret checks if the same secret value already exists for this computer/type +func (s *PostgresStore) findExistingSecret(computerID int, secretType, plainSecret string) (*Secret, error) { + rows, err := s.db.Query( + `SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE computer_id = $1 AND secret_type = $2`, + computerID, secretType, + ) + if err != nil { + return nil, fmt.Errorf("find existing secret: %w", err) + } + defer rows.Close() + + for rows.Next() { + var secret Secret + if err := rows.Scan(&secret.ID, &secret.ComputerID, &secret.SecretType, &secret.Secret, &secret.DateEscrowed, &secret.RotationRequired); err != nil { + return nil, fmt.Errorf("scan secret: %w", err) + } + decrypted, err := s.decryptSecret(&secret) + if err != nil { + return nil, err + } + if decrypted.Secret == plainSecret { + return decrypted, nil + } + } + return nil, nil +} + +func (s *PostgresStore) ListSecretsByComputer(computerID int) ([]*Secret, error) { + rows, err := s.db.Query(`SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE computer_id = $1 ORDER BY id`, computerID) + if err != nil { + return nil, fmt.Errorf("list secrets: %w", err) + } + defer rows.Close() + + secrets := make([]*Secret, 0) + for rows.Next() { + var secret Secret + if err := rows.Scan(&secret.ID, &secret.ComputerID, &secret.SecretType, &secret.Secret, &secret.DateEscrowed, &secret.RotationRequired); err != nil { + return nil, fmt.Errorf("scan secret: %w", err) + } + decrypted, err := s.decryptSecret(&secret) + if err != nil { + return nil, err + } + secrets = append(secrets, decrypted) + } + return secrets, rows.Err() +} + +func (s *PostgresStore) GetSecretByID(id int) (*Secret, error) { + var secret Secret + row := s.db.QueryRow(`SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE id = $1`, id) + if err := row.Scan(&secret.ID, &secret.ComputerID, &secret.SecretType, &secret.Secret, &secret.DateEscrowed, &secret.RotationRequired); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get secret: %w", err) + } + return s.decryptSecret(&secret) +} + +func (s *PostgresStore) GetLatestSecretByComputerAndType(computerID int, secretType string) (*Secret, error) { + var secret Secret + row := s.db.QueryRow( + `SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required + FROM secrets + WHERE computer_id = $1 AND secret_type = $2 + ORDER BY date_escrowed DESC + LIMIT 1`, + computerID, secretType, + ) + if err := row.Scan(&secret.ID, &secret.ComputerID, &secret.SecretType, &secret.Secret, &secret.DateEscrowed, &secret.RotationRequired); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get latest secret: %w", err) + } + return s.decryptSecret(&secret) +} + +func (s *PostgresStore) AddRequest(secretID int, requestingUser, reason string, approvedBy string, approved *bool) (*Request, error) { + var id int + var dateRequested time.Time + var dateApproved sql.NullTime + var approvedValue sql.NullBool + if approved != nil { + approvedValue.Valid = true + approvedValue.Bool = *approved + } + row := s.db.QueryRow( + `INSERT INTO requests (secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current) + VALUES ($1, $2, $3, $4, $5, $6, NOW(), CASE WHEN $3 IS NULL THEN NULL ELSE NOW() END, true) + RETURNING id, date_requested, date_approved`, + secretID, requestingUser, approvedValue, nullableString(approvedBy), reason, sql.NullString{}, + ) + if err := row.Scan(&id, &dateRequested, &dateApproved); err != nil { + return nil, fmt.Errorf("insert request: %w", err) + } + + var approvedPtr *bool + if approvedValue.Valid { + approvedPtr = &approvedValue.Bool + } + var dateApprovedPtr *time.Time + if dateApproved.Valid { + dateApprovedPtr = &dateApproved.Time + } + + return &Request{ + ID: id, + SecretID: secretID, + RequestingUser: requestingUser, + Approved: approvedPtr, + AuthUser: approvedBy, + ReasonForRequest: reason, + ReasonForApproval: "", + DateRequested: dateRequested, + DateApproved: dateApprovedPtr, + Current: true, + }, nil +} + +func (s *PostgresStore) ListRequestsBySecret(secretID int) ([]*Request, error) { + rows, err := s.db.Query(`SELECT id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current FROM requests WHERE secret_id = $1 ORDER BY id`, secretID) + if err != nil { + return nil, fmt.Errorf("list requests: %w", err) + } + defer rows.Close() + + requests := make([]*Request, 0) + for rows.Next() { + request, err := scanRequest(rows) + if err != nil { + return nil, err + } + requests = append(requests, request) + } + return requests, rows.Err() +} + +func (s *PostgresStore) ListOutstandingRequests() ([]*Request, error) { + rows, err := s.db.Query(`SELECT id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current FROM requests WHERE current = true AND approved IS NULL ORDER BY id`) + if err != nil { + return nil, fmt.Errorf("list outstanding requests: %w", err) + } + defer rows.Close() + + requests := make([]*Request, 0) + for rows.Next() { + request, err := scanRequest(rows) + if err != nil { + return nil, err + } + requests = append(requests, request) + } + return requests, rows.Err() +} + +func (s *PostgresStore) GetRequestByID(id int) (*Request, error) { + row := s.db.QueryRow(`SELECT id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current FROM requests WHERE id = $1`, id) + request, err := scanRequest(row) + if err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, err + } + return request, nil +} + +func (s *PostgresStore) ApproveRequest(requestID int, approved bool, reason, approver string) (*Request, error) { + var dateApproved time.Time + row := s.db.QueryRow( + `UPDATE requests + SET approved = $1, reason_for_approval = $2, auth_user = $3, date_approved = NOW() + WHERE id = $4 + RETURNING date_approved`, + approved, reason, approver, requestID, + ) + if err := row.Scan(&dateApproved); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("approve request: %w", err) + } + updated, err := s.GetRequestByID(requestID) + if err != nil { + return nil, err + } + updated.DateApproved = &dateApproved + return updated, nil +} + +func (s *PostgresStore) AddUser(username, passwordHash string, isStaff, canApprove, localLoginEnabled, mustResetPassword bool, authSource string) (*User, error) { + var id int + err := s.db.QueryRow( + `INSERT INTO users (username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source) + VALUES ($1, $2, $3, $4, $5, $6, $7) + RETURNING id`, + username, nullableString(passwordHash), isStaff, canApprove, localLoginEnabled, mustResetPassword, authSource, + ).Scan(&id) + if err != nil { + return nil, fmt.Errorf("insert user: %w", err) + } + return &User{ + ID: id, + Username: username, + PasswordHash: passwordHash, + IsStaff: isStaff, + CanApprove: canApprove, + LocalLoginEnabled: localLoginEnabled, + MustResetPassword: mustResetPassword, + AuthSource: authSource, + }, nil +} + +func (s *PostgresStore) GetUserByUsername(username string) (*User, error) { + var user User + var passwordHash sql.NullString + row := s.db.QueryRow( + `SELECT id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source + FROM users WHERE lower(username) = lower($1)`, + username, + ) + if err := row.Scan(&user.ID, &user.Username, &passwordHash, &user.IsStaff, &user.CanApprove, &user.LocalLoginEnabled, &user.MustResetPassword, &user.AuthSource); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get user by username: %w", err) + } + if passwordHash.Valid { + user.PasswordHash = passwordHash.String + } + return &user, nil +} + +func (s *PostgresStore) ListUsers() ([]*User, error) { + rows, err := s.db.Query(`SELECT id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source FROM users ORDER BY id`) + if err != nil { + return nil, fmt.Errorf("list users: %w", err) + } + defer rows.Close() + + users := make([]*User, 0) + for rows.Next() { + var user User + var passwordHash sql.NullString + if err := rows.Scan(&user.ID, &user.Username, &passwordHash, &user.IsStaff, &user.CanApprove, &user.LocalLoginEnabled, &user.MustResetPassword, &user.AuthSource); err != nil { + return nil, fmt.Errorf("scan user: %w", err) + } + if passwordHash.Valid { + user.PasswordHash = passwordHash.String + } + users = append(users, &user) + } + return users, rows.Err() +} + +func (s *PostgresStore) GetUserByID(id int) (*User, error) { + var user User + var passwordHash sql.NullString + row := s.db.QueryRow( + `SELECT id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source + FROM users WHERE id = $1`, + id, + ) + if err := row.Scan(&user.ID, &user.Username, &passwordHash, &user.IsStaff, &user.CanApprove, &user.LocalLoginEnabled, &user.MustResetPassword, &user.AuthSource); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get user by id: %w", err) + } + if passwordHash.Valid { + user.PasswordHash = passwordHash.String + } + return &user, nil +} + +func (s *PostgresStore) UpdateUser(id int, username string, isStaff, canApprove, localLoginEnabled, mustResetPassword bool, authSource string) (*User, error) { + var user User + var passwordHash sql.NullString + row := s.db.QueryRow( + `UPDATE users + SET username = $1, is_staff = $2, can_approve = $3, local_login_enabled = $4, must_reset_password = $5, auth_source = $6 + WHERE id = $7 + RETURNING id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source`, + username, isStaff, canApprove, localLoginEnabled, mustResetPassword, authSource, id, + ) + if err := row.Scan(&user.ID, &user.Username, &passwordHash, &user.IsStaff, &user.CanApprove, &user.LocalLoginEnabled, &user.MustResetPassword, &user.AuthSource); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("update user: %w", err) + } + if passwordHash.Valid { + user.PasswordHash = passwordHash.String + } + return &user, nil +} + +func (s *PostgresStore) UpdateUserPassword(id int, passwordHash string, mustResetPassword bool) (*User, error) { + var user User + var hash sql.NullString + row := s.db.QueryRow( + `UPDATE users + SET password_hash = $1, must_reset_password = $2, local_login_enabled = CASE WHEN $1 IS NULL THEN local_login_enabled ELSE true END + WHERE id = $3 + RETURNING id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source`, + nullableString(passwordHash), mustResetPassword, id, + ) + if err := row.Scan(&user.ID, &user.Username, &hash, &user.IsStaff, &user.CanApprove, &user.LocalLoginEnabled, &user.MustResetPassword, &user.AuthSource); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("update user password: %w", err) + } + if hash.Valid { + user.PasswordHash = hash.String + } + return &user, nil +} + +func (s *PostgresStore) DeleteUser(id int) error { + result, err := s.db.Exec(`DELETE FROM users WHERE id = $1`, id) + if err != nil { + return fmt.Errorf("delete user: %w", err) + } + affected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("delete user: %w", err) + } + if affected == 0 { + return ErrNotFound + } + return nil +} + +func (s *PostgresStore) CleanupRequests(approvedBefore time.Time) (int, error) { + result, err := s.db.Exec( + `UPDATE requests SET current = false WHERE current = true AND approved IS NOT NULL AND date_approved < $1`, + approvedBefore, + ) + if err != nil { + return 0, fmt.Errorf("cleanup requests: %w", err) + } + affected, err := result.RowsAffected() + if err != nil { + return 0, fmt.Errorf("cleanup requests: %w", err) + } + return int(affected), nil +} + +func (s *PostgresStore) SetSecretRotationRequired(secretID int, rotationRequired bool) (*Secret, error) { + var secret Secret + row := s.db.QueryRow( + `UPDATE secrets + SET rotation_required = $1 + WHERE id = $2 + RETURNING id, computer_id, secret_type, secret, date_escrowed, rotation_required`, + rotationRequired, secretID, + ) + if err := row.Scan(&secret.ID, &secret.ComputerID, &secret.SecretType, &secret.Secret, &secret.DateEscrowed, &secret.RotationRequired); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("update secret rotation: %w", err) + } + return s.decryptSecret(&secret) +} + +func (s *PostgresStore) AddAuditEvent(actor, targetUser, action, reason, ipAddress string) (*AuditEvent, error) { + var id int + var createdAt time.Time + err := s.db.QueryRow( + `INSERT INTO audit_events (actor, target_user, action, reason, ip_address) + VALUES ($1, $2, $3, $4, $5) + RETURNING id, created_at`, + actor, targetUser, action, nullableString(reason), nullableString(ipAddress), + ).Scan(&id, &createdAt) + if err != nil { + return nil, fmt.Errorf("insert audit event: %w", err) + } + return &AuditEvent{ + ID: id, + Actor: actor, + TargetUser: targetUser, + Action: action, + Reason: reason, + IPAddress: ipAddress, + CreatedAt: createdAt, + }, nil +} + +func (s *PostgresStore) ListAuditEvents() ([]*AuditEvent, error) { + rows, err := s.db.Query(`SELECT id, actor, target_user, action, reason, ip_address, created_at FROM audit_events ORDER BY created_at DESC, id DESC`) + if err != nil { + return nil, fmt.Errorf("list audit events: %w", err) + } + defer rows.Close() + + return scanAuditEvents(rows) +} + +func (s *PostgresStore) SearchAuditEvents(query string) ([]*AuditEvent, error) { + pattern := "%" + query + "%" + rows, err := s.db.Query( + `SELECT id, actor, target_user, action, reason, ip_address, created_at + FROM audit_events + WHERE actor ILIKE $1 + OR target_user ILIKE $1 + OR action ILIKE $1 + OR COALESCE(reason, '') ILIKE $1 + OR COALESCE(ip_address, '') ILIKE $1 + ORDER BY created_at DESC, id DESC`, + pattern, + ) + if err != nil { + return nil, fmt.Errorf("search audit events: %w", err) + } + defer rows.Close() + + return scanAuditEvents(rows) +} + +func (s *PostgresStore) ListAuditEventsPaged(limit, offset int) ([]*AuditEvent, error) { + rows, err := s.db.Query( + `SELECT id, actor, target_user, action, reason, ip_address, created_at + FROM audit_events + ORDER BY created_at DESC, id DESC + LIMIT $1 OFFSET $2`, + limit, offset, + ) + if err != nil { + return nil, fmt.Errorf("list audit events paged: %w", err) + } + defer rows.Close() + + return scanAuditEvents(rows) +} + +func (s *PostgresStore) SearchAuditEventsPaged(query string, limit, offset int) ([]*AuditEvent, error) { + pattern := "%" + query + "%" + rows, err := s.db.Query( + `SELECT id, actor, target_user, action, reason, ip_address, created_at + FROM audit_events + WHERE actor ILIKE $1 + OR target_user ILIKE $1 + OR action ILIKE $1 + OR COALESCE(reason, '') ILIKE $1 + OR COALESCE(ip_address, '') ILIKE $1 + ORDER BY created_at DESC, id DESC + LIMIT $2 OFFSET $3`, + pattern, limit, offset, + ) + if err != nil { + return nil, fmt.Errorf("search audit events paged: %w", err) + } + defer rows.Close() + + return scanAuditEvents(rows) +} + +func (s *PostgresStore) CountAuditEvents() (int, error) { + var count int + if err := s.db.QueryRow(`SELECT COUNT(*) FROM audit_events`).Scan(&count); err != nil { + return 0, fmt.Errorf("count audit events: %w", err) + } + return count, nil +} + +func (s *PostgresStore) CountSearchAuditEvents(query string) (int, error) { + pattern := "%" + query + "%" + var count int + if err := s.db.QueryRow( + `SELECT COUNT(*) + FROM audit_events + WHERE actor ILIKE $1 + OR target_user ILIKE $1 + OR action ILIKE $1 + OR COALESCE(reason, '') ILIKE $1 + OR COALESCE(ip_address, '') ILIKE $1`, + pattern, + ).Scan(&count); err != nil { + return 0, fmt.Errorf("count audit events search: %w", err) + } + return count, nil +} + +func (s *PostgresStore) decryptSecret(secret *Secret) (*Secret, error) { + if s.codec == nil { + return nil, ErrMissingCodec + } + plaintext, err := s.codec.Decrypt(secret.Secret) + if err != nil { + return nil, err + } + clone := *secret + clone.Secret = plaintext + return &clone, nil +} + +type scanner interface { + Scan(dest ...any) error +} + +func scanRequest(row scanner) (*Request, error) { + var request Request + var approved sql.NullBool + var authUser sql.NullString + var reasonApproval sql.NullString + var dateApproved sql.NullTime + if err := row.Scan( + &request.ID, + &request.SecretID, + &request.RequestingUser, + &approved, + &authUser, + &request.ReasonForRequest, + &reasonApproval, + &request.DateRequested, + &dateApproved, + &request.Current, + ); err != nil { + return nil, err + } + + if approved.Valid { + value := approved.Bool + request.Approved = &value + } + if authUser.Valid { + request.AuthUser = authUser.String + } + if reasonApproval.Valid { + request.ReasonForApproval = reasonApproval.String + } + if dateApproved.Valid { + request.DateApproved = &dateApproved.Time + } + + return &request, nil +} + +func scanAuditEvents(rows *sql.Rows) ([]*AuditEvent, error) { + events := make([]*AuditEvent, 0) + for rows.Next() { + var event AuditEvent + var reason sql.NullString + var ipAddress sql.NullString + if err := rows.Scan(&event.ID, &event.Actor, &event.TargetUser, &event.Action, &reason, &ipAddress, &event.CreatedAt); err != nil { + return nil, fmt.Errorf("scan audit event: %w", err) + } + if reason.Valid { + event.Reason = reason.String + } + if ipAddress.Valid { + event.IPAddress = ipAddress.String + } + events = append(events, &event) + } + return events, rows.Err() +} + +func nullableString(value string) sql.NullString { + if strings.TrimSpace(value) == "" { + return sql.NullString{} + } + return sql.NullString{String: value, Valid: true} +} + +func (s *PostgresStore) IsEmpty() (bool, error) { + tables := []string{"computers", "secrets", "requests", "users"} + for _, table := range tables { + var count int + query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table) + if err := s.db.QueryRow(query).Scan(&count); err != nil { + return false, fmt.Errorf("count %s: %w", table, err) + } + if count > 0 { + return false, nil + } + } + return true, nil +} + +func (s *PostgresStore) ImportComputer(id int, serial, username, computerName string, lastCheckin time.Time) error { + _, err := s.db.Exec( + `INSERT INTO computers (id, serial, username, computername, last_checkin) VALUES ($1, $2, $3, $4, $5)`, + id, serial, username, computerName, lastCheckin, + ) + if err != nil { + return fmt.Errorf("import computer: %w", err) + } + return nil +} + +func (s *PostgresStore) ImportSecret(id, computerID int, secretType, encryptedSecret string, dateEscrowed time.Time, rotationRequired bool) error { + _, err := s.db.Exec( + `INSERT INTO secrets (id, computer_id, secret_type, secret, date_escrowed, rotation_required) VALUES ($1, $2, $3, $4, $5, $6)`, + id, computerID, secretType, encryptedSecret, dateEscrowed, rotationRequired, + ) + if err != nil { + return fmt.Errorf("import secret: %w", err) + } + return nil +} + +func (s *PostgresStore) ImportRequest(id, secretID int, requestingUser string, approved *bool, authUser, reasonForRequest, reasonForApproval string, dateRequested time.Time, dateApproved *time.Time, current bool) error { + var approvedValue sql.NullBool + if approved != nil { + approvedValue.Valid = true + approvedValue.Bool = *approved + } + var dateApprovedValue sql.NullTime + if dateApproved != nil { + dateApprovedValue.Valid = true + dateApprovedValue.Time = *dateApproved + } + _, err := s.db.Exec( + `INSERT INTO requests (id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)`, + id, secretID, requestingUser, approvedValue, nullableString(authUser), reasonForRequest, nullableString(reasonForApproval), dateRequested, dateApprovedValue, current, + ) + if err != nil { + return fmt.Errorf("import request: %w", err) + } + return nil +} + +func (s *PostgresStore) ImportUser(id int, username, passwordHash string, isStaff, canApprove, localLoginEnabled, mustResetPassword bool, authSource string) error { + _, err := s.db.Exec( + `INSERT INTO users (id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, + id, username, nullableString(passwordHash), isStaff, canApprove, localLoginEnabled, mustResetPassword, authSource, + ) + if err != nil { + return fmt.Errorf("import user: %w", err) + } + return nil +} diff --git a/internal/store/postgres_test.go b/internal/store/postgres_test.go new file mode 100644 index 0000000..94058dc --- /dev/null +++ b/internal/store/postgres_test.go @@ -0,0 +1,398 @@ +package store + +import ( + "regexp" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" +) + +func TestPostgresStoreAddComputer(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewPostgresStoreWithDB(db, testCodec(t)) + lastCheckin := time.Now() + mock.ExpectQuery(regexp.QuoteMeta( + "INSERT INTO computers (serial, username, computername, last_checkin) VALUES ($1, $2, $3, NOW()) RETURNING id, last_checkin", + )).WithArgs("SERIAL", "user", "Mac").WillReturnRows(sqlmock.NewRows([]string{"id", "last_checkin"}).AddRow(1, lastCheckin)) + + computer, err := store.AddComputer("SERIAL", "user", "Mac") + require.NoError(t, err) + require.Equal(t, 1, computer.ID) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPostgresStoreDB(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewPostgresStoreWithDB(db, testCodec(t)) + require.NotNil(t, store.DB()) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPostgresStoreUpsertComputer(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewPostgresStoreWithDB(db, testCodec(t)) + lastCheckin := time.Now() + mock.ExpectQuery(regexp.QuoteMeta( + "INSERT INTO computers (serial, username, computername, last_checkin) VALUES ($1, $2, $3, $4) ON CONFLICT (serial) DO UPDATE SET username = EXCLUDED.username, computername = EXCLUDED.computername, last_checkin = EXCLUDED.last_checkin RETURNING id, last_checkin", + )).WithArgs("SERIAL", "user", "Mac", lastCheckin).WillReturnRows(sqlmock.NewRows([]string{"id", "last_checkin"}).AddRow(1, lastCheckin)) + + computer, err := store.UpsertComputer("SERIAL", "user", "Mac", lastCheckin) + require.NoError(t, err) + require.Equal(t, 1, computer.ID) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPostgresStoreListComputers(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewPostgresStoreWithDB(db, testCodec(t)) + now := time.Now() + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, serial, username, computername, last_checkin FROM computers ORDER BY id", + )).WillReturnRows(sqlmock.NewRows([]string{"id", "serial", "username", "computername", "last_checkin"}).AddRow(1, "SERIAL", "user", "Mac", now)) + + computers, err := store.ListComputers() + require.NoError(t, err) + require.Len(t, computers, 1) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPostgresStoreGetComputerByID(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewPostgresStoreWithDB(db, testCodec(t)) + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, serial, username, computername, last_checkin FROM computers WHERE id = $1", + )).WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "serial", "username", "computername", "last_checkin"}).AddRow(1, "SERIAL", "user", "Mac", time.Now())) + + computer, err := store.GetComputerByID(1) + require.NoError(t, err) + require.Equal(t, "SERIAL", computer.Serial) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPostgresStoreAddSecret(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewPostgresStoreWithDB(db, testCodec(t)) + now := time.Now() + // Expect the duplicate check query first (returns no rows = no duplicate) + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE computer_id = $1 AND secret_type = $2", + )).WithArgs(1, "password").WillReturnRows(sqlmock.NewRows([]string{"id", "computer_id", "secret_type", "secret", "date_escrowed", "rotation_required"})) + // Then expect the insert + mock.ExpectQuery(regexp.QuoteMeta( + "INSERT INTO secrets (computer_id, secret_type, secret, date_escrowed, rotation_required) VALUES ($1, $2, $3, NOW(), $4) RETURNING id, date_escrowed", + )).WithArgs(1, "password", sqlmock.AnyArg(), false).WillReturnRows(sqlmock.NewRows([]string{"id", "date_escrowed"}).AddRow(5, now)) + + secret, isNew, err := store.AddSecret(1, "password", "secret", false) + require.NoError(t, err) + require.Equal(t, 5, secret.ID) + require.True(t, isNew) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPostgresStoreAddRequestAndApprove(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewPostgresStoreWithDB(db, testCodec(t)) + now := time.Now() + mock.ExpectQuery(regexp.QuoteMeta( + "INSERT INTO requests (secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current) VALUES ($1, $2, $3, $4, $5, $6, NOW(), CASE WHEN $3 IS NULL THEN NULL ELSE NOW() END, true) RETURNING id, date_requested, date_approved", + )).WithArgs(9, "user", sqlmock.AnyArg(), sqlmock.AnyArg(), "reason", sqlmock.AnyArg()).WillReturnRows(sqlmock.NewRows([]string{"id", "date_requested", "date_approved"}).AddRow(3, now, nil)) + + request, err := store.AddRequest(9, "user", "reason", "", nil) + require.NoError(t, err) + require.Equal(t, 3, request.ID) + + mock.ExpectQuery(regexp.QuoteMeta( + "UPDATE requests SET approved = $1, reason_for_approval = $2, auth_user = $3, date_approved = NOW() WHERE id = $4 RETURNING date_approved", + )).WithArgs(true, "ok", "admin", 3).WillReturnRows(sqlmock.NewRows([]string{"date_approved"}).AddRow(now)) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current FROM requests WHERE id = $1", + )).WithArgs(3).WillReturnRows(sqlmock.NewRows([]string{"id", "secret_id", "requesting_user", "approved", "auth_user", "reason_for_request", "reason_for_approval", "date_requested", "date_approved", "current"}).AddRow(3, 9, "user", true, "admin", "reason", "ok", now, now, true)) + + approved, err := store.ApproveRequest(3, true, "ok", "admin") + require.NoError(t, err) + require.NotNil(t, approved.Approved) + require.True(t, *approved.Approved) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPostgresStoreGetSecretAndRequests(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + codec := testCodec(t) + store := NewPostgresStoreWithDB(db, codec) + now := time.Now() + + encryptedSecret, err := codec.Encrypt("secret") + require.NoError(t, err) + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE id = $1", + )).WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id", "computer_id", "secret_type", "secret", "date_escrowed", "rotation_required"}).AddRow(2, 1, "password", encryptedSecret, now, false)) + + secret, err := store.GetSecretByID(2) + require.NoError(t, err) + require.Equal(t, "password", secret.SecretType) + require.Equal(t, "secret", secret.Secret) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE computer_id = $1 AND secret_type = $2 ORDER BY date_escrowed DESC LIMIT 1", + )).WithArgs(1, "password").WillReturnRows(sqlmock.NewRows([]string{"id", "computer_id", "secret_type", "secret", "date_escrowed", "rotation_required"}).AddRow(3, 1, "password", encryptedSecret, now, true)) + + latest, err := store.GetLatestSecretByComputerAndType(1, "password") + require.NoError(t, err) + require.Equal(t, 3, latest.ID) + require.True(t, latest.RotationRequired) + + encryptedSecret2, err := codec.Encrypt("secret") + require.NoError(t, err) + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE computer_id = $1 ORDER BY id", + )).WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "computer_id", "secret_type", "secret", "date_escrowed", "rotation_required"}).AddRow(2, 1, "password", encryptedSecret2, now, false)) + + secrets, err := store.ListSecretsByComputer(1) + require.NoError(t, err) + require.Len(t, secrets, 1) + require.Equal(t, "secret", secrets[0].Secret) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current FROM requests WHERE secret_id = $1 ORDER BY id", + )).WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id", "secret_id", "requesting_user", "approved", "auth_user", "reason_for_request", "reason_for_approval", "date_requested", "date_approved", "current"}).AddRow(5, 2, "user", nil, nil, "reason", nil, now, nil, true)) + + requests, err := store.ListRequestsBySecret(2) + require.NoError(t, err) + require.Len(t, requests, 1) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current FROM requests WHERE current = true AND approved IS NULL ORDER BY id", + )).WillReturnRows(sqlmock.NewRows([]string{"id", "secret_id", "requesting_user", "approved", "auth_user", "reason_for_request", "reason_for_approval", "date_requested", "date_approved", "current"}).AddRow(5, 2, "user", nil, nil, "reason", nil, now, nil, true)) + + outstanding, err := store.ListOutstandingRequests() + require.NoError(t, err) + require.Len(t, outstanding, 1) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current FROM requests WHERE id = $1", + )).WithArgs(5).WillReturnRows(sqlmock.NewRows([]string{"id", "secret_id", "requesting_user", "approved", "auth_user", "reason_for_request", "reason_for_approval", "date_requested", "date_approved", "current"}).AddRow(5, 2, "user", nil, nil, "reason", nil, now, nil, true)) + + request, err := store.GetRequestByID(5) + require.NoError(t, err) + require.Equal(t, 2, request.SecretID) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPostgresStoreUserLifecycle(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewPostgresStoreWithDB(db, testCodec(t)) + mock.ExpectQuery(regexp.QuoteMeta( + "INSERT INTO users (username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id", + )).WithArgs("admin", sqlmock.AnyArg(), true, true, true, false, "local").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(7)) + + user, err := store.AddUser("admin", "hash", true, true, true, false, "local") + require.NoError(t, err) + require.Equal(t, 7, user.ID) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source FROM users WHERE lower(username) = lower($1)", + )).WithArgs("admin").WillReturnRows(sqlmock.NewRows([]string{"id", "username", "password_hash", "is_staff", "can_approve", "local_login_enabled", "must_reset_password", "auth_source"}).AddRow(7, "admin", "hash", true, true, true, false, "local")) + + loaded, err := store.GetUserByUsername("admin") + require.NoError(t, err) + require.Equal(t, "admin", loaded.Username) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source FROM users WHERE id = $1", + )).WithArgs(7).WillReturnRows(sqlmock.NewRows([]string{"id", "username", "password_hash", "is_staff", "can_approve", "local_login_enabled", "must_reset_password", "auth_source"}).AddRow(7, "admin", "hash", true, true, true, false, "local")) + + byID, err := store.GetUserByID(7) + require.NoError(t, err) + require.Equal(t, "admin", byID.Username) + + mock.ExpectQuery(regexp.QuoteMeta( + "UPDATE users SET username = $1, is_staff = $2, can_approve = $3, local_login_enabled = $4, must_reset_password = $5, auth_source = $6 WHERE id = $7 RETURNING id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source", + )).WithArgs("updated", false, false, false, true, "local", 7).WillReturnRows(sqlmock.NewRows([]string{"id", "username", "password_hash", "is_staff", "can_approve", "local_login_enabled", "must_reset_password", "auth_source"}).AddRow(7, "updated", "hash", false, false, false, true, "local")) + + updated, err := store.UpdateUser(7, "updated", false, false, false, true, "local") + require.NoError(t, err) + require.Equal(t, "updated", updated.Username) + + mock.ExpectQuery(regexp.QuoteMeta( + "UPDATE users SET password_hash = $1, must_reset_password = $2, local_login_enabled = CASE WHEN $1 IS NULL THEN local_login_enabled ELSE true END WHERE id = $3 RETURNING id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source", + )).WithArgs(sqlmock.AnyArg(), false, 7).WillReturnRows(sqlmock.NewRows([]string{"id", "username", "password_hash", "is_staff", "can_approve", "local_login_enabled", "must_reset_password", "auth_source"}).AddRow(7, "updated", "newhash", false, false, true, false, "local")) + + passwordUpdated, err := store.UpdateUserPassword(7, "newhash", false) + require.NoError(t, err) + require.Equal(t, "newhash", passwordUpdated.PasswordHash) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source FROM users ORDER BY id", + )).WillReturnRows(sqlmock.NewRows([]string{"id", "username", "password_hash", "is_staff", "can_approve", "local_login_enabled", "must_reset_password", "auth_source"}). + AddRow(7, "admin", "hash", true, true, true, false, "local"). + AddRow(8, "viewer", nil, false, false, false, false, "saml")) + + users, err := store.ListUsers() + require.NoError(t, err) + require.Len(t, users, 2) + + mock.ExpectExec(regexp.QuoteMeta( + "DELETE FROM users WHERE id = $1", + )).WithArgs(7).WillReturnResult(sqlmock.NewResult(0, 1)) + + require.NoError(t, store.DeleteUser(7)) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPostgresStoreCleanupRequests(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewPostgresStoreWithDB(db, testCodec(t)) + cutoff := time.Now().Add(-7 * 24 * time.Hour) + mock.ExpectExec(regexp.QuoteMeta( + "UPDATE requests SET current = false WHERE current = true AND approved IS NOT NULL AND date_approved < $1", + )).WithArgs(cutoff).WillReturnResult(sqlmock.NewResult(0, 2)) + + updated, err := store.CleanupRequests(cutoff) + require.NoError(t, err) + require.Equal(t, 2, updated) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPostgresStoreSetSecretRotationRequired(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + codec := testCodec(t) + store := NewPostgresStoreWithDB(db, codec) + now := time.Now() + encrypted, err := codec.Encrypt("secret") + require.NoError(t, err) + mock.ExpectQuery(regexp.QuoteMeta( + "UPDATE secrets SET rotation_required = $1 WHERE id = $2 RETURNING id, computer_id, secret_type, secret, date_escrowed, rotation_required", + )).WithArgs(true, 5).WillReturnRows(sqlmock.NewRows([]string{"id", "computer_id", "secret_type", "secret", "date_escrowed", "rotation_required"}).AddRow(5, 2, "password", encrypted, now, true)) + + updated, err := store.SetSecretRotationRequired(5, true) + require.NoError(t, err) + require.True(t, updated.RotationRequired) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPostgresStoreAuditEvents(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewPostgresStoreWithDB(db, testCodec(t)) + now := time.Now() + mock.ExpectQuery(regexp.QuoteMeta( + "INSERT INTO audit_events (actor, target_user, action, reason, ip_address) VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at", + )).WithArgs("admin", "user", "password_reset", sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(1, now)) + + event, err := store.AddAuditEvent("admin", "user", "password_reset", "", "") + require.NoError(t, err) + require.Equal(t, 1, event.ID) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, actor, target_user, action, reason, ip_address, created_at FROM audit_events ORDER BY created_at DESC, id DESC", + )).WillReturnRows(sqlmock.NewRows([]string{"id", "actor", "target_user", "action", "reason", "ip_address", "created_at"}). + AddRow(2, "admin", "user", "force_reset_enabled", nil, nil, now)) + + events, err := store.ListAuditEvents() + require.NoError(t, err) + require.Len(t, events, 1) + require.Equal(t, "force_reset_enabled", events[0].Action) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPostgresStoreSearchAuditEvents(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewPostgresStoreWithDB(db, testCodec(t)) + now := time.Now() + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, actor, target_user, action, reason, ip_address, created_at FROM audit_events WHERE actor ILIKE $1 OR target_user ILIKE $1 OR action ILIKE $1 OR COALESCE(reason, '') ILIKE $1 OR COALESCE(ip_address, '') ILIKE $1 ORDER BY created_at DESC, id DESC", + )).WithArgs("%reset%").WillReturnRows(sqlmock.NewRows([]string{"id", "actor", "target_user", "action", "reason", "ip_address", "created_at"}). + AddRow(3, "admin", "user", "password_reset", "reason", "127.0.0.1", now)) + + events, err := store.SearchAuditEvents("reset") + require.NoError(t, err) + require.Len(t, events, 1) + require.Equal(t, "password_reset", events[0].Action) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPostgresStoreAuditEventsPagingAndCount(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewPostgresStoreWithDB(db, testCodec(t)) + now := time.Now() + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT COUNT(*) FROM audit_events", + )).WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(120)) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, actor, target_user, action, reason, ip_address, created_at FROM audit_events ORDER BY created_at DESC, id DESC LIMIT $1 OFFSET $2", + )).WithArgs(50, 50). + WillReturnRows(sqlmock.NewRows([]string{"id", "actor", "target_user", "action", "reason", "ip_address", "created_at"}). + AddRow(10, "admin", "user", "password_reset", nil, nil, now)) + + count, err := store.CountAuditEvents() + require.NoError(t, err) + require.Equal(t, 120, count) + + events, err := store.ListAuditEventsPaged(50, 50) + require.NoError(t, err) + require.Len(t, events, 1) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPostgresStoreCountSearchAuditEvents(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewPostgresStoreWithDB(db, testCodec(t)) + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT COUNT(*) FROM audit_events WHERE actor ILIKE $1 OR target_user ILIKE $1 OR action ILIKE $1 OR COALESCE(reason, '') ILIKE $1 OR COALESCE(ip_address, '') ILIKE $1", + )).WithArgs("%reset%").WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(5)) + + count, err := store.CountSearchAuditEvents("reset") + require.NoError(t, err) + require.Equal(t, 5, count) + require.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/internal/store/sqlite.go b/internal/store/sqlite.go new file mode 100644 index 0000000..8238797 --- /dev/null +++ b/internal/store/sqlite.go @@ -0,0 +1,761 @@ +package store + +import ( + "context" + "database/sql" + "fmt" + "time" + + _ "modernc.org/sqlite" +) + +type SQLiteStore struct { + db *sql.DB + codec SecretCodec +} + +func (s *SQLiteStore) DB() *sql.DB { + return s.db +} + +func NewSQLiteStore(dsn string, codec SecretCodec) (*SQLiteStore, error) { + if dsn == "" { + return nil, fmt.Errorf("sqlite dsn is required") + } + db, err := sql.Open("sqlite", dsn) + if err != nil { + return nil, fmt.Errorf("open db: %w", err) + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := db.PingContext(ctx); err != nil { + return nil, fmt.Errorf("ping db: %w", err) + } + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + return nil, fmt.Errorf("enable foreign keys: %w", err) + } + return &SQLiteStore{db: db, codec: codec}, nil +} + +func NewSQLiteStoreWithDB(db *sql.DB, codec SecretCodec) *SQLiteStore { + return &SQLiteStore{db: db, codec: codec} +} + +func (s *SQLiteStore) AddComputer(serial, username, computerName string) (*Computer, error) { + now := time.Now() + var id int + var lastCheckin time.Time + row := s.db.QueryRow( + "INSERT INTO computers (serial, username, computername, last_checkin) VALUES (?, ?, ?, ?) RETURNING id, last_checkin", + serial, username, computerName, now, + ) + if err := row.Scan(&id, &lastCheckin); err != nil { + return nil, fmt.Errorf("insert computer: %w", err) + } + return &Computer{ + ID: id, + Serial: serial, + Username: username, + ComputerName: computerName, + LastCheckin: lastCheckin, + }, nil +} + +func (s *SQLiteStore) UpsertComputer(serial, username, computerName string, lastCheckin time.Time) (*Computer, error) { + var id int + var stored time.Time + row := s.db.QueryRow( + "INSERT INTO computers (serial, username, computername, last_checkin) VALUES (?, ?, ?, ?) ON CONFLICT(serial) DO UPDATE SET username = excluded.username, computername = excluded.computername, last_checkin = excluded.last_checkin RETURNING id, last_checkin", + serial, username, computerName, lastCheckin, + ) + if err := row.Scan(&id, &stored); err != nil { + return nil, fmt.Errorf("upsert computer: %w", err) + } + return &Computer{ + ID: id, + Serial: serial, + Username: username, + ComputerName: computerName, + LastCheckin: stored, + }, nil +} + +func (s *SQLiteStore) ListComputers() ([]*Computer, error) { + rows, err := s.db.Query("SELECT id, serial, username, computername, last_checkin FROM computers ORDER BY id") + if err != nil { + return nil, fmt.Errorf("list computers: %w", err) + } + defer rows.Close() + + computers := make([]*Computer, 0) + for rows.Next() { + var computer Computer + var lastCheckin sql.NullTime + if err := rows.Scan(&computer.ID, &computer.Serial, &computer.Username, &computer.ComputerName, &lastCheckin); err != nil { + return nil, fmt.Errorf("scan computer: %w", err) + } + if lastCheckin.Valid { + computer.LastCheckin = lastCheckin.Time + } + computers = append(computers, &computer) + } + return computers, rows.Err() +} + +func (s *SQLiteStore) GetComputerByID(id int) (*Computer, error) { + var computer Computer + var lastCheckin sql.NullTime + row := s.db.QueryRow("SELECT id, serial, username, computername, last_checkin FROM computers WHERE id = ?", id) + if err := row.Scan(&computer.ID, &computer.Serial, &computer.Username, &computer.ComputerName, &lastCheckin); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get computer by id: %w", err) + } + if lastCheckin.Valid { + computer.LastCheckin = lastCheckin.Time + } + return &computer, nil +} + +func (s *SQLiteStore) GetComputerBySerial(serial string) (*Computer, error) { + var computer Computer + var lastCheckin sql.NullTime + row := s.db.QueryRow("SELECT id, serial, username, computername, last_checkin FROM computers WHERE lower(serial) = lower(?)", serial) + if err := row.Scan(&computer.ID, &computer.Serial, &computer.Username, &computer.ComputerName, &lastCheckin); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get computer by serial: %w", err) + } + if lastCheckin.Valid { + computer.LastCheckin = lastCheckin.Time + } + return &computer, nil +} + +func (s *SQLiteStore) AddSecret(computerID int, secretType, secret string, rotationRequired bool) (*Secret, bool, error) { + if s.codec == nil { + return nil, false, ErrMissingCodec + } + + // Check for duplicate secret (matching Django behavior) + // If the same secret value already exists for this computer/type and rotation is not required, skip insert + if !rotationRequired { + existing, err := s.findExistingSecret(computerID, secretType, secret) + if err != nil { + return nil, false, err + } + if existing != nil { + return existing, false, nil + } + } + + encrypted, err := s.codec.Encrypt(secret) + if err != nil { + return nil, false, err + } + now := time.Now() + var id int + var dateEscrowed time.Time + row := s.db.QueryRow( + "INSERT INTO secrets (computer_id, secret_type, secret, date_escrowed, rotation_required) VALUES (?, ?, ?, ?, ?) RETURNING id, date_escrowed", + computerID, secretType, encrypted, now, rotationRequired, + ) + if err := row.Scan(&id, &dateEscrowed); err != nil { + return nil, false, fmt.Errorf("insert secret: %w", err) + } + decrypted, err := s.decryptSecret(&Secret{ + ID: id, + ComputerID: computerID, + SecretType: secretType, + Secret: encrypted, + DateEscrowed: dateEscrowed, + RotationRequired: rotationRequired, + }) + if err != nil { + return nil, false, err + } + return decrypted, true, nil +} + +// findExistingSecret checks if the same secret value already exists for this computer/type +func (s *SQLiteStore) findExistingSecret(computerID int, secretType, plainSecret string) (*Secret, error) { + rows, err := s.db.Query( + "SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE computer_id = ? AND secret_type = ?", + computerID, secretType, + ) + if err != nil { + return nil, fmt.Errorf("find existing secret: %w", err) + } + defer rows.Close() + + for rows.Next() { + secret, err := scanSecret(rows) + if err != nil { + return nil, err + } + decrypted, err := s.decryptSecret(secret) + if err != nil { + return nil, err + } + if decrypted.Secret == plainSecret { + return decrypted, nil + } + } + return nil, nil +} + +func (s *SQLiteStore) ListSecretsByComputer(computerID int) ([]*Secret, error) { + rows, err := s.db.Query("SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE computer_id = ? ORDER BY id", computerID) + if err != nil { + return nil, fmt.Errorf("list secrets: %w", err) + } + defer rows.Close() + + secrets := make([]*Secret, 0) + for rows.Next() { + secret, err := scanSecret(rows) + if err != nil { + return nil, err + } + decrypted, err := s.decryptSecret(secret) + if err != nil { + return nil, err + } + secrets = append(secrets, decrypted) + } + return secrets, rows.Err() +} + +func (s *SQLiteStore) GetSecretByID(id int) (*Secret, error) { + row := s.db.QueryRow("SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE id = ?", id) + secret, err := scanSecret(row) + if err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get secret: %w", err) + } + return s.decryptSecret(secret) +} + +func (s *SQLiteStore) GetLatestSecretByComputerAndType(computerID int, secretType string) (*Secret, error) { + row := s.db.QueryRow( + "SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE computer_id = ? AND secret_type = ? ORDER BY date_escrowed DESC LIMIT 1", + computerID, secretType, + ) + secret, err := scanSecret(row) + if err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get latest secret: %w", err) + } + return s.decryptSecret(secret) +} + +func (s *SQLiteStore) AddRequest(secretID int, requestingUser, reason string, approvedBy string, approved *bool) (*Request, error) { + now := time.Now() + var id int + var dateRequested time.Time + var dateApproved sql.NullTime + var approvedValue sql.NullBool + if approved != nil { + approvedValue.Valid = true + approvedValue.Bool = *approved + dateApproved.Valid = true + dateApproved.Time = now + } + row := s.db.QueryRow( + "INSERT INTO requests (secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) RETURNING id, date_requested, date_approved", + secretID, requestingUser, approvedValue, nullableString(approvedBy), reason, sql.NullString{}, now, dateApproved, true, + ) + if err := row.Scan(&id, &dateRequested, &dateApproved); err != nil { + return nil, fmt.Errorf("insert request: %w", err) + } + var approvedPtr *bool + if approvedValue.Valid { + value := approvedValue.Bool + approvedPtr = &value + } + var dateApprovedPtr *time.Time + if dateApproved.Valid { + dateApprovedPtr = &dateApproved.Time + } + return &Request{ + ID: id, + SecretID: secretID, + RequestingUser: requestingUser, + Approved: approvedPtr, + AuthUser: approvedBy, + ReasonForRequest: reason, + ReasonForApproval: "", + DateRequested: dateRequested, + DateApproved: dateApprovedPtr, + Current: true, + }, nil +} + +func (s *SQLiteStore) ListRequestsBySecret(secretID int) ([]*Request, error) { + rows, err := s.db.Query("SELECT id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current FROM requests WHERE secret_id = ? ORDER BY id", secretID) + if err != nil { + return nil, fmt.Errorf("list requests: %w", err) + } + defer rows.Close() + + requests := make([]*Request, 0) + for rows.Next() { + request, err := scanRequest(rows) + if err != nil { + return nil, err + } + requests = append(requests, request) + } + return requests, rows.Err() +} + +func (s *SQLiteStore) ListOutstandingRequests() ([]*Request, error) { + rows, err := s.db.Query("SELECT id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current FROM requests WHERE current = true AND approved IS NULL ORDER BY id") + if err != nil { + return nil, fmt.Errorf("list outstanding requests: %w", err) + } + defer rows.Close() + + requests := make([]*Request, 0) + for rows.Next() { + request, err := scanRequest(rows) + if err != nil { + return nil, err + } + requests = append(requests, request) + } + return requests, rows.Err() +} + +func (s *SQLiteStore) GetRequestByID(id int) (*Request, error) { + row := s.db.QueryRow("SELECT id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current FROM requests WHERE id = ?", id) + request, err := scanRequest(row) + if err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, err + } + return request, nil +} + +func (s *SQLiteStore) ApproveRequest(requestID int, approved bool, reason, approver string) (*Request, error) { + dateApproved := time.Now() + result, err := s.db.Exec( + "UPDATE requests SET approved = ?, reason_for_approval = ?, auth_user = ?, date_approved = ? WHERE id = ?", + approved, reason, approver, dateApproved, requestID, + ) + if err != nil { + return nil, fmt.Errorf("approve request: %w", err) + } + affected, err := result.RowsAffected() + if err != nil { + return nil, fmt.Errorf("approve request: %w", err) + } + if affected == 0 { + return nil, ErrNotFound + } + updated, err := s.GetRequestByID(requestID) + if err != nil { + return nil, err + } + if updated.DateApproved == nil { + updated.DateApproved = &dateApproved + } + return updated, nil +} + +func (s *SQLiteStore) AddUser(username, passwordHash string, isStaff, canApprove, localLoginEnabled, mustResetPassword bool, authSource string) (*User, error) { + var id int + row := s.db.QueryRow( + "INSERT INTO users (username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source) VALUES (?, ?, ?, ?, ?, ?, ?) RETURNING id", + username, nullableString(passwordHash), isStaff, canApprove, localLoginEnabled, mustResetPassword, authSource, + ) + if err := row.Scan(&id); err != nil { + return nil, fmt.Errorf("insert user: %w", err) + } + return &User{ + ID: id, + Username: username, + PasswordHash: passwordHash, + IsStaff: isStaff, + CanApprove: canApprove, + LocalLoginEnabled: localLoginEnabled, + MustResetPassword: mustResetPassword, + AuthSource: authSource, + }, nil +} + +func (s *SQLiteStore) GetUserByUsername(username string) (*User, error) { + var user User + var passwordHash sql.NullString + row := s.db.QueryRow("SELECT id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source FROM users WHERE lower(username) = lower(?)", username) + if err := row.Scan(&user.ID, &user.Username, &passwordHash, &user.IsStaff, &user.CanApprove, &user.LocalLoginEnabled, &user.MustResetPassword, &user.AuthSource); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get user by username: %w", err) + } + if passwordHash.Valid { + user.PasswordHash = passwordHash.String + } + return &user, nil +} + +func (s *SQLiteStore) ListUsers() ([]*User, error) { + rows, err := s.db.Query("SELECT id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source FROM users ORDER BY id") + if err != nil { + return nil, fmt.Errorf("list users: %w", err) + } + defer rows.Close() + + users := make([]*User, 0) + for rows.Next() { + var user User + var passwordHash sql.NullString + if err := rows.Scan(&user.ID, &user.Username, &passwordHash, &user.IsStaff, &user.CanApprove, &user.LocalLoginEnabled, &user.MustResetPassword, &user.AuthSource); err != nil { + return nil, fmt.Errorf("scan user: %w", err) + } + if passwordHash.Valid { + user.PasswordHash = passwordHash.String + } + users = append(users, &user) + } + return users, rows.Err() +} + +func (s *SQLiteStore) GetUserByID(id int) (*User, error) { + var user User + var passwordHash sql.NullString + row := s.db.QueryRow("SELECT id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source FROM users WHERE id = ?", id) + if err := row.Scan(&user.ID, &user.Username, &passwordHash, &user.IsStaff, &user.CanApprove, &user.LocalLoginEnabled, &user.MustResetPassword, &user.AuthSource); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("get user by id: %w", err) + } + if passwordHash.Valid { + user.PasswordHash = passwordHash.String + } + return &user, nil +} + +func (s *SQLiteStore) UpdateUser(id int, username string, isStaff, canApprove, localLoginEnabled, mustResetPassword bool, authSource string) (*User, error) { + var user User + var passwordHash sql.NullString + row := s.db.QueryRow( + "UPDATE users SET username = ?, is_staff = ?, can_approve = ?, local_login_enabled = ?, must_reset_password = ?, auth_source = ? WHERE id = ? RETURNING id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source", + username, isStaff, canApprove, localLoginEnabled, mustResetPassword, authSource, id, + ) + if err := row.Scan(&user.ID, &user.Username, &passwordHash, &user.IsStaff, &user.CanApprove, &user.LocalLoginEnabled, &user.MustResetPassword, &user.AuthSource); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("update user: %w", err) + } + if passwordHash.Valid { + user.PasswordHash = passwordHash.String + } + return &user, nil +} + +func (s *SQLiteStore) UpdateUserPassword(id int, passwordHash string, mustResetPassword bool) (*User, error) { + var user User + var hash sql.NullString + row := s.db.QueryRow( + "UPDATE users SET password_hash = ?, must_reset_password = ?, local_login_enabled = CASE WHEN ? IS NULL THEN local_login_enabled ELSE 1 END WHERE id = ? RETURNING id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source", + nullableString(passwordHash), mustResetPassword, nullableString(passwordHash), id, + ) + if err := row.Scan(&user.ID, &user.Username, &hash, &user.IsStaff, &user.CanApprove, &user.LocalLoginEnabled, &user.MustResetPassword, &user.AuthSource); err != nil { + if err == sql.ErrNoRows { + return nil, ErrNotFound + } + return nil, fmt.Errorf("update user password: %w", err) + } + if hash.Valid { + user.PasswordHash = hash.String + } + return &user, nil +} + +func (s *SQLiteStore) DeleteUser(id int) error { + result, err := s.db.Exec("DELETE FROM users WHERE id = ?", id) + if err != nil { + return fmt.Errorf("delete user: %w", err) + } + affected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("delete user: %w", err) + } + if affected == 0 { + return ErrNotFound + } + return nil +} + +func (s *SQLiteStore) CleanupRequests(approvedBefore time.Time) (int, error) { + result, err := s.db.Exec( + "UPDATE requests SET current = 0 WHERE current = 1 AND approved IS NOT NULL AND date_approved < ?", + approvedBefore, + ) + if err != nil { + return 0, fmt.Errorf("cleanup requests: %w", err) + } + affected, err := result.RowsAffected() + if err != nil { + return 0, fmt.Errorf("cleanup requests: %w", err) + } + return int(affected), nil +} + +func (s *SQLiteStore) SetSecretRotationRequired(secretID int, rotationRequired bool) (*Secret, error) { + result, err := s.db.Exec( + "UPDATE secrets SET rotation_required = ? WHERE id = ?", + rotationRequired, secretID, + ) + if err != nil { + return nil, fmt.Errorf("update secret rotation: %w", err) + } + affected, err := result.RowsAffected() + if err != nil { + return nil, fmt.Errorf("update secret rotation: %w", err) + } + if affected == 0 { + return nil, ErrNotFound + } + return s.GetSecretByID(secretID) +} + +func (s *SQLiteStore) AddAuditEvent(actor, targetUser, action, reason, ipAddress string) (*AuditEvent, error) { + var id int + var createdAt time.Time + row := s.db.QueryRow( + "INSERT INTO audit_events (actor, target_user, action, reason, ip_address) VALUES (?, ?, ?, ?, ?) RETURNING id, created_at", + actor, targetUser, action, nullableString(reason), nullableString(ipAddress), + ) + if err := row.Scan(&id, &createdAt); err != nil { + return nil, fmt.Errorf("insert audit event: %w", err) + } + return &AuditEvent{ + ID: id, + Actor: actor, + TargetUser: targetUser, + Action: action, + Reason: reason, + IPAddress: ipAddress, + CreatedAt: createdAt, + }, nil +} + +func (s *SQLiteStore) ListAuditEvents() ([]*AuditEvent, error) { + rows, err := s.db.Query("SELECT id, actor, target_user, action, reason, ip_address, created_at FROM audit_events ORDER BY created_at DESC, id DESC") + if err != nil { + return nil, fmt.Errorf("list audit events: %w", err) + } + defer rows.Close() + + return scanAuditEventsSQLite(rows) +} + +func (s *SQLiteStore) SearchAuditEvents(query string) ([]*AuditEvent, error) { + pattern := "%" + query + "%" + rows, err := s.db.Query( + `SELECT id, actor, target_user, action, reason, ip_address, created_at + FROM audit_events + WHERE lower(actor) LIKE lower(?) + OR lower(target_user) LIKE lower(?) + OR lower(action) LIKE lower(?) + OR lower(COALESCE(reason, '')) LIKE lower(?) + OR lower(COALESCE(ip_address, '')) LIKE lower(?) + ORDER BY created_at DESC, id DESC`, + pattern, pattern, pattern, pattern, pattern, + ) + if err != nil { + return nil, fmt.Errorf("search audit events: %w", err) + } + defer rows.Close() + + return scanAuditEventsSQLite(rows) +} + +func (s *SQLiteStore) ListAuditEventsPaged(limit, offset int) ([]*AuditEvent, error) { + rows, err := s.db.Query( + `SELECT id, actor, target_user, action, reason, ip_address, created_at + FROM audit_events + ORDER BY created_at DESC, id DESC + LIMIT ? OFFSET ?`, + limit, offset, + ) + if err != nil { + return nil, fmt.Errorf("list audit events paged: %w", err) + } + defer rows.Close() + + return scanAuditEventsSQLite(rows) +} + +func (s *SQLiteStore) SearchAuditEventsPaged(query string, limit, offset int) ([]*AuditEvent, error) { + pattern := "%" + query + "%" + rows, err := s.db.Query( + `SELECT id, actor, target_user, action, reason, ip_address, created_at + FROM audit_events + WHERE lower(actor) LIKE lower(?) + OR lower(target_user) LIKE lower(?) + OR lower(action) LIKE lower(?) + OR lower(COALESCE(reason, '')) LIKE lower(?) + OR lower(COALESCE(ip_address, '')) LIKE lower(?) + ORDER BY created_at DESC, id DESC + LIMIT ? OFFSET ?`, + pattern, pattern, pattern, pattern, pattern, limit, offset, + ) + if err != nil { + return nil, fmt.Errorf("search audit events paged: %w", err) + } + defer rows.Close() + + return scanAuditEventsSQLite(rows) +} + +func (s *SQLiteStore) CountAuditEvents() (int, error) { + var count int + if err := s.db.QueryRow(`SELECT COUNT(*) FROM audit_events`).Scan(&count); err != nil { + return 0, fmt.Errorf("count audit events: %w", err) + } + return count, nil +} + +func (s *SQLiteStore) CountSearchAuditEvents(query string) (int, error) { + pattern := "%" + query + "%" + var count int + if err := s.db.QueryRow( + `SELECT COUNT(*) + FROM audit_events + WHERE lower(actor) LIKE lower(?) + OR lower(target_user) LIKE lower(?) + OR lower(action) LIKE lower(?) + OR lower(COALESCE(reason, '')) LIKE lower(?) + OR lower(COALESCE(ip_address, '')) LIKE lower(?)`, + pattern, pattern, pattern, pattern, pattern, + ).Scan(&count); err != nil { + return 0, fmt.Errorf("count audit events search: %w", err) + } + return count, nil +} + +func (s *SQLiteStore) decryptSecret(secret *Secret) (*Secret, error) { + if s.codec == nil { + return nil, ErrMissingCodec + } + plaintext, err := s.codec.Decrypt(secret.Secret) + if err != nil { + return nil, err + } + clone := *secret + clone.Secret = plaintext + return &clone, nil +} + +func scanSecret(row scanner) (*Secret, error) { + var secret Secret + var rotation int + if err := row.Scan(&secret.ID, &secret.ComputerID, &secret.SecretType, &secret.Secret, &secret.DateEscrowed, &rotation); err != nil { + return nil, err + } + secret.RotationRequired = rotation != 0 + return &secret, nil +} + +func scanAuditEventsSQLite(rows *sql.Rows) ([]*AuditEvent, error) { + events := make([]*AuditEvent, 0) + for rows.Next() { + var event AuditEvent + var reason sql.NullString + var ipAddress sql.NullString + if err := rows.Scan(&event.ID, &event.Actor, &event.TargetUser, &event.Action, &reason, &ipAddress, &event.CreatedAt); err != nil { + return nil, fmt.Errorf("scan audit event: %w", err) + } + if reason.Valid { + event.Reason = reason.String + } + if ipAddress.Valid { + event.IPAddress = ipAddress.String + } + events = append(events, &event) + } + return events, rows.Err() +} + +func (s *SQLiteStore) IsEmpty() (bool, error) { + tables := []string{"computers", "secrets", "requests", "users"} + for _, table := range tables { + var count int + query := fmt.Sprintf("SELECT COUNT(*) FROM %s", table) + if err := s.db.QueryRow(query).Scan(&count); err != nil { + return false, fmt.Errorf("count %s: %w", table, err) + } + if count > 0 { + return false, nil + } + } + return true, nil +} + +func (s *SQLiteStore) ImportComputer(id int, serial, username, computerName string, lastCheckin time.Time) error { + _, err := s.db.Exec( + "INSERT INTO computers (id, serial, username, computername, last_checkin) VALUES (?, ?, ?, ?, ?)", + id, serial, username, computerName, lastCheckin, + ) + if err != nil { + return fmt.Errorf("import computer: %w", err) + } + return nil +} + +func (s *SQLiteStore) ImportSecret(id, computerID int, secretType, encryptedSecret string, dateEscrowed time.Time, rotationRequired bool) error { + _, err := s.db.Exec( + "INSERT INTO secrets (id, computer_id, secret_type, secret, date_escrowed, rotation_required) VALUES (?, ?, ?, ?, ?, ?)", + id, computerID, secretType, encryptedSecret, dateEscrowed, rotationRequired, + ) + if err != nil { + return fmt.Errorf("import secret: %w", err) + } + return nil +} + +func (s *SQLiteStore) ImportRequest(id, secretID int, requestingUser string, approved *bool, authUser, reasonForRequest, reasonForApproval string, dateRequested time.Time, dateApproved *time.Time, current bool) error { + var approvedValue sql.NullBool + if approved != nil { + approvedValue.Valid = true + approvedValue.Bool = *approved + } + var dateApprovedValue sql.NullTime + if dateApproved != nil { + dateApprovedValue.Valid = true + dateApprovedValue.Time = *dateApproved + } + _, err := s.db.Exec( + "INSERT INTO requests (id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + id, secretID, requestingUser, approvedValue, nullableString(authUser), reasonForRequest, nullableString(reasonForApproval), dateRequested, dateApprovedValue, current, + ) + if err != nil { + return fmt.Errorf("import request: %w", err) + } + return nil +} + +func (s *SQLiteStore) ImportUser(id int, username, passwordHash string, isStaff, canApprove, localLoginEnabled, mustResetPassword bool, authSource string) error { + _, err := s.db.Exec( + "INSERT INTO users (id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + id, username, nullableString(passwordHash), isStaff, canApprove, localLoginEnabled, mustResetPassword, authSource, + ) + if err != nil { + return fmt.Errorf("import user: %w", err) + } + return nil +} diff --git a/internal/store/sqlite_test.go b/internal/store/sqlite_test.go new file mode 100644 index 0000000..5797749 --- /dev/null +++ b/internal/store/sqlite_test.go @@ -0,0 +1,420 @@ +package store + +import ( + "path/filepath" + "regexp" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" +) + +func TestSQLiteStoreAddComputer(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + codec := testCodec(t) + store := NewSQLiteStoreWithDB(db, codec) + lastCheckin := time.Now() + mock.ExpectQuery(regexp.QuoteMeta( + "INSERT INTO computers (serial, username, computername, last_checkin) VALUES (?, ?, ?, ?) RETURNING id, last_checkin", + )).WithArgs("SERIAL", "user", "Mac", sqlmock.AnyArg()).WillReturnRows(sqlmock.NewRows([]string{"id", "last_checkin"}).AddRow(1, lastCheckin)) + + computer, err := store.AddComputer("SERIAL", "user", "Mac") + require.NoError(t, err) + require.Equal(t, 1, computer.ID) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestNewSQLiteStoreRequiresDSN(t *testing.T) { + _, err := NewSQLiteStore("", testCodec(t)) + require.Error(t, err) +} + +func TestNewSQLiteStoreOpensDatabase(t *testing.T) { + path := filepath.Join(t.TempDir(), "crypt.db") + + store, err := NewSQLiteStore(path, testCodec(t)) + require.NoError(t, err) + require.NotNil(t, store) + require.NoError(t, store.db.Close()) +} + +func TestSQLiteStoreDB(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewSQLiteStoreWithDB(db, testCodec(t)) + require.NotNil(t, store.DB()) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSQLiteStoreUpsertComputer(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewSQLiteStoreWithDB(db, testCodec(t)) + lastCheckin := time.Now() + mock.ExpectQuery(regexp.QuoteMeta( + "INSERT INTO computers (serial, username, computername, last_checkin) VALUES (?, ?, ?, ?) ON CONFLICT(serial) DO UPDATE SET username = excluded.username, computername = excluded.computername, last_checkin = excluded.last_checkin RETURNING id, last_checkin", + )).WithArgs("SERIAL", "user", "Mac", lastCheckin).WillReturnRows(sqlmock.NewRows([]string{"id", "last_checkin"}).AddRow(1, lastCheckin)) + + computer, err := store.UpsertComputer("SERIAL", "user", "Mac", lastCheckin) + require.NoError(t, err) + require.Equal(t, 1, computer.ID) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSQLiteStoreListComputers(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewSQLiteStoreWithDB(db, testCodec(t)) + now := time.Now() + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, serial, username, computername, last_checkin FROM computers ORDER BY id", + )).WillReturnRows(sqlmock.NewRows([]string{"id", "serial", "username", "computername", "last_checkin"}).AddRow(1, "SERIAL", "user", "Mac", now)) + + computers, err := store.ListComputers() + require.NoError(t, err) + require.Len(t, computers, 1) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSQLiteStoreGetComputerByID(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewSQLiteStoreWithDB(db, testCodec(t)) + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, serial, username, computername, last_checkin FROM computers WHERE id = ?", + )).WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "serial", "username", "computername", "last_checkin"}).AddRow(1, "SERIAL", "user", "Mac", time.Now())) + + computer, err := store.GetComputerByID(1) + require.NoError(t, err) + require.Equal(t, "SERIAL", computer.Serial) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSQLiteStoreAddSecret(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewSQLiteStoreWithDB(db, testCodec(t)) + now := time.Now() + // Expect the duplicate check query first (returns no rows = no duplicate) + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE computer_id = ? AND secret_type = ?", + )).WithArgs(1, "password").WillReturnRows(sqlmock.NewRows([]string{"id", "computer_id", "secret_type", "secret", "date_escrowed", "rotation_required"})) + // Then expect the insert + mock.ExpectQuery(regexp.QuoteMeta( + "INSERT INTO secrets (computer_id, secret_type, secret, date_escrowed, rotation_required) VALUES (?, ?, ?, ?, ?) RETURNING id, date_escrowed", + )).WithArgs(1, "password", sqlmock.AnyArg(), sqlmock.AnyArg(), false).WillReturnRows(sqlmock.NewRows([]string{"id", "date_escrowed"}).AddRow(5, now)) + + secret, isNew, err := store.AddSecret(1, "password", "secret", false) + require.NoError(t, err) + require.Equal(t, 5, secret.ID) + require.True(t, isNew) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSQLiteStoreAddRequestAndApprove(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewSQLiteStoreWithDB(db, testCodec(t)) + now := time.Now() + mock.ExpectQuery(regexp.QuoteMeta( + "INSERT INTO requests (secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) RETURNING id, date_requested, date_approved", + )).WithArgs(9, "user", sqlmock.AnyArg(), sqlmock.AnyArg(), "reason", sqlmock.AnyArg(), sqlmock.AnyArg(), sqlmock.AnyArg(), true).WillReturnRows(sqlmock.NewRows([]string{"id", "date_requested", "date_approved"}).AddRow(3, now, nil)) + + request, err := store.AddRequest(9, "user", "reason", "", nil) + require.NoError(t, err) + require.Equal(t, 3, request.ID) + + mock.ExpectExec(regexp.QuoteMeta( + "UPDATE requests SET approved = ?, reason_for_approval = ?, auth_user = ?, date_approved = ? WHERE id = ?", + )).WithArgs(true, "ok", "admin", sqlmock.AnyArg(), 3).WillReturnResult(sqlmock.NewResult(0, 1)) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current FROM requests WHERE id = ?", + )).WithArgs(3).WillReturnRows(sqlmock.NewRows([]string{"id", "secret_id", "requesting_user", "approved", "auth_user", "reason_for_request", "reason_for_approval", "date_requested", "date_approved", "current"}).AddRow(3, 9, "user", true, "admin", "reason", "ok", now, now, true)) + + approved, err := store.ApproveRequest(3, true, "ok", "admin") + require.NoError(t, err) + require.NotNil(t, approved.Approved) + require.True(t, *approved.Approved) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSQLiteStoreGetSecretAndRequests(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + codec := testCodec(t) + store := NewSQLiteStoreWithDB(db, codec) + now := time.Now() + + encryptedSecret, err := codec.Encrypt("secret") + require.NoError(t, err) + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE id = ?", + )).WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id", "computer_id", "secret_type", "secret", "date_escrowed", "rotation_required"}).AddRow(2, 1, "password", encryptedSecret, now, 0)) + + secret, err := store.GetSecretByID(2) + require.NoError(t, err) + require.Equal(t, "password", secret.SecretType) + require.Equal(t, "secret", secret.Secret) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE computer_id = ? AND secret_type = ? ORDER BY date_escrowed DESC LIMIT 1", + )).WithArgs(1, "password").WillReturnRows(sqlmock.NewRows([]string{"id", "computer_id", "secret_type", "secret", "date_escrowed", "rotation_required"}).AddRow(3, 1, "password", encryptedSecret, now, 1)) + + latest, err := store.GetLatestSecretByComputerAndType(1, "password") + require.NoError(t, err) + require.Equal(t, 3, latest.ID) + require.True(t, latest.RotationRequired) + + encryptedSecret2, err := codec.Encrypt("secret") + require.NoError(t, err) + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE computer_id = ? ORDER BY id", + )).WithArgs(1).WillReturnRows(sqlmock.NewRows([]string{"id", "computer_id", "secret_type", "secret", "date_escrowed", "rotation_required"}).AddRow(2, 1, "password", encryptedSecret2, now, 0)) + + secrets, err := store.ListSecretsByComputer(1) + require.NoError(t, err) + require.Len(t, secrets, 1) + require.Equal(t, "secret", secrets[0].Secret) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current FROM requests WHERE secret_id = ? ORDER BY id", + )).WithArgs(2).WillReturnRows(sqlmock.NewRows([]string{"id", "secret_id", "requesting_user", "approved", "auth_user", "reason_for_request", "reason_for_approval", "date_requested", "date_approved", "current"}).AddRow(5, 2, "user", nil, nil, "reason", nil, now, nil, true)) + + requests, err := store.ListRequestsBySecret(2) + require.NoError(t, err) + require.Len(t, requests, 1) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current FROM requests WHERE current = true AND approved IS NULL ORDER BY id", + )).WillReturnRows(sqlmock.NewRows([]string{"id", "secret_id", "requesting_user", "approved", "auth_user", "reason_for_request", "reason_for_approval", "date_requested", "date_approved", "current"}).AddRow(5, 2, "user", nil, nil, "reason", nil, now, nil, true)) + + outstanding, err := store.ListOutstandingRequests() + require.NoError(t, err) + require.Len(t, outstanding, 1) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, secret_id, requesting_user, approved, auth_user, reason_for_request, reason_for_approval, date_requested, date_approved, current FROM requests WHERE id = ?", + )).WithArgs(5).WillReturnRows(sqlmock.NewRows([]string{"id", "secret_id", "requesting_user", "approved", "auth_user", "reason_for_request", "reason_for_approval", "date_requested", "date_approved", "current"}).AddRow(5, 2, "user", nil, nil, "reason", nil, now, nil, true)) + + request, err := store.GetRequestByID(5) + require.NoError(t, err) + require.Equal(t, 2, request.SecretID) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSQLiteStoreUserLifecycle(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewSQLiteStoreWithDB(db, testCodec(t)) + mock.ExpectQuery(regexp.QuoteMeta( + "INSERT INTO users (username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source) VALUES (?, ?, ?, ?, ?, ?, ?) RETURNING id", + )).WithArgs("admin", sqlmock.AnyArg(), true, true, true, false, "local").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(7)) + + user, err := store.AddUser("admin", "hash", true, true, true, false, "local") + require.NoError(t, err) + require.Equal(t, 7, user.ID) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source FROM users WHERE lower(username) = lower(?)", + )).WithArgs("admin").WillReturnRows(sqlmock.NewRows([]string{"id", "username", "password_hash", "is_staff", "can_approve", "local_login_enabled", "must_reset_password", "auth_source"}).AddRow(7, "admin", "hash", true, true, true, false, "local")) + + loaded, err := store.GetUserByUsername("admin") + require.NoError(t, err) + require.Equal(t, "admin", loaded.Username) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source FROM users WHERE id = ?", + )).WithArgs(7).WillReturnRows(sqlmock.NewRows([]string{"id", "username", "password_hash", "is_staff", "can_approve", "local_login_enabled", "must_reset_password", "auth_source"}).AddRow(7, "admin", "hash", true, true, true, false, "local")) + + byID, err := store.GetUserByID(7) + require.NoError(t, err) + require.Equal(t, "admin", byID.Username) + + mock.ExpectQuery(regexp.QuoteMeta( + "UPDATE users SET username = ?, is_staff = ?, can_approve = ?, local_login_enabled = ?, must_reset_password = ?, auth_source = ? WHERE id = ? RETURNING id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source", + )).WithArgs("updated", false, false, false, true, "local", 7).WillReturnRows(sqlmock.NewRows([]string{"id", "username", "password_hash", "is_staff", "can_approve", "local_login_enabled", "must_reset_password", "auth_source"}).AddRow(7, "updated", "hash", false, false, false, true, "local")) + + updated, err := store.UpdateUser(7, "updated", false, false, false, true, "local") + require.NoError(t, err) + require.Equal(t, "updated", updated.Username) + + mock.ExpectQuery(regexp.QuoteMeta( + "UPDATE users SET password_hash = ?, must_reset_password = ?, local_login_enabled = CASE WHEN ? IS NULL THEN local_login_enabled ELSE 1 END WHERE id = ? RETURNING id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source", + )).WithArgs(sqlmock.AnyArg(), false, sqlmock.AnyArg(), 7).WillReturnRows(sqlmock.NewRows([]string{"id", "username", "password_hash", "is_staff", "can_approve", "local_login_enabled", "must_reset_password", "auth_source"}).AddRow(7, "updated", "newhash", false, false, true, false, "local")) + + passwordUpdated, err := store.UpdateUserPassword(7, "newhash", false) + require.NoError(t, err) + require.Equal(t, "newhash", passwordUpdated.PasswordHash) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, username, password_hash, is_staff, can_approve, local_login_enabled, must_reset_password, auth_source FROM users ORDER BY id", + )).WillReturnRows(sqlmock.NewRows([]string{"id", "username", "password_hash", "is_staff", "can_approve", "local_login_enabled", "must_reset_password", "auth_source"}). + AddRow(7, "admin", "hash", true, true, true, false, "local"). + AddRow(8, "viewer", nil, false, false, false, false, "saml")) + + users, err := store.ListUsers() + require.NoError(t, err) + require.Len(t, users, 2) + + mock.ExpectExec(regexp.QuoteMeta( + "DELETE FROM users WHERE id = ?", + )).WithArgs(7).WillReturnResult(sqlmock.NewResult(0, 1)) + + require.NoError(t, store.DeleteUser(7)) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSQLiteStoreCleanupRequests(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewSQLiteStoreWithDB(db, testCodec(t)) + cutoff := time.Now().Add(-7 * 24 * time.Hour) + mock.ExpectExec(regexp.QuoteMeta( + "UPDATE requests SET current = 0 WHERE current = 1 AND approved IS NOT NULL AND date_approved < ?", + )).WithArgs(cutoff).WillReturnResult(sqlmock.NewResult(0, 3)) + + updated, err := store.CleanupRequests(cutoff) + require.NoError(t, err) + require.Equal(t, 3, updated) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSQLiteStoreSetSecretRotationRequired(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + codec := testCodec(t) + store := NewSQLiteStoreWithDB(db, codec) + mock.ExpectExec(regexp.QuoteMeta( + "UPDATE secrets SET rotation_required = ? WHERE id = ?", + )).WithArgs(true, 7).WillReturnResult(sqlmock.NewResult(0, 1)) + + encrypted, err := codec.Encrypt("secret") + require.NoError(t, err) + now := time.Now() + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, computer_id, secret_type, secret, date_escrowed, rotation_required FROM secrets WHERE id = ?", + )).WithArgs(7).WillReturnRows(sqlmock.NewRows([]string{"id", "computer_id", "secret_type", "secret", "date_escrowed", "rotation_required"}).AddRow(7, 2, "password", encrypted, now, 1)) + + updated, err := store.SetSecretRotationRequired(7, true) + require.NoError(t, err) + require.True(t, updated.RotationRequired) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSQLiteStoreAuditEvents(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewSQLiteStoreWithDB(db, testCodec(t)) + now := time.Now() + mock.ExpectQuery(regexp.QuoteMeta( + "INSERT INTO audit_events (actor, target_user, action, reason, ip_address) VALUES (?, ?, ?, ?, ?) RETURNING id, created_at", + )).WithArgs("admin", "user", "password_reset", sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(1, now)) + + event, err := store.AddAuditEvent("admin", "user", "password_reset", "", "") + require.NoError(t, err) + require.Equal(t, 1, event.ID) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, actor, target_user, action, reason, ip_address, created_at FROM audit_events ORDER BY created_at DESC, id DESC", + )).WillReturnRows(sqlmock.NewRows([]string{"id", "actor", "target_user", "action", "reason", "ip_address", "created_at"}). + AddRow(2, "admin", "user", "force_reset_enabled", nil, nil, now)) + + events, err := store.ListAuditEvents() + require.NoError(t, err) + require.Len(t, events, 1) + require.Equal(t, "force_reset_enabled", events[0].Action) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSQLiteStoreSearchAuditEvents(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewSQLiteStoreWithDB(db, testCodec(t)) + now := time.Now() + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, actor, target_user, action, reason, ip_address, created_at FROM audit_events WHERE lower(actor) LIKE lower(?) OR lower(target_user) LIKE lower(?) OR lower(action) LIKE lower(?) OR lower(COALESCE(reason, '')) LIKE lower(?) OR lower(COALESCE(ip_address, '')) LIKE lower(?) ORDER BY created_at DESC, id DESC", + )).WithArgs("%reset%", "%reset%", "%reset%", "%reset%", "%reset%"). + WillReturnRows(sqlmock.NewRows([]string{"id", "actor", "target_user", "action", "reason", "ip_address", "created_at"}). + AddRow(3, "admin", "user", "password_reset", "reason", "127.0.0.1", now)) + + events, err := store.SearchAuditEvents("reset") + require.NoError(t, err) + require.Len(t, events, 1) + require.Equal(t, "password_reset", events[0].Action) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSQLiteStoreAuditEventsPagingAndCount(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewSQLiteStoreWithDB(db, testCodec(t)) + now := time.Now() + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT COUNT(*) FROM audit_events", + )).WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(120)) + + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT id, actor, target_user, action, reason, ip_address, created_at FROM audit_events ORDER BY created_at DESC, id DESC LIMIT ? OFFSET ?", + )).WithArgs(50, 50). + WillReturnRows(sqlmock.NewRows([]string{"id", "actor", "target_user", "action", "reason", "ip_address", "created_at"}). + AddRow(10, "admin", "user", "password_reset", nil, nil, now)) + + count, err := store.CountAuditEvents() + require.NoError(t, err) + require.Equal(t, 120, count) + + events, err := store.ListAuditEventsPaged(50, 50) + require.NoError(t, err) + require.Len(t, events, 1) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestSQLiteStoreCountSearchAuditEvents(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + store := NewSQLiteStoreWithDB(db, testCodec(t)) + mock.ExpectQuery(regexp.QuoteMeta( + "SELECT COUNT(*) FROM audit_events WHERE lower(actor) LIKE lower(?) OR lower(target_user) LIKE lower(?) OR lower(action) LIKE lower(?) OR lower(COALESCE(reason, '')) LIKE lower(?) OR lower(COALESCE(ip_address, '')) LIKE lower(?)", + )).WithArgs("%reset%", "%reset%", "%reset%", "%reset%", "%reset%"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(5)) + + count, err := store.CountSearchAuditEvents("reset") + require.NoError(t, err) + require.Equal(t, 5, count) + require.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/internal/store/store.go b/internal/store/store.go new file mode 100644 index 0000000..23cc4a3 --- /dev/null +++ b/internal/store/store.go @@ -0,0 +1,60 @@ +package store + +import ( + "errors" + "time" +) + +var ErrNotFound = errors.New("not found") +var ErrMissingCodec = errors.New("secret codec is required") + +type SecretCodec interface { + Encrypt(plaintext string) (string, error) + Decrypt(ciphertext string) (string, error) +} + +type Store interface { + AddComputer(serial, username, computerName string) (*Computer, error) + UpsertComputer(serial, username, computerName string, lastCheckin time.Time) (*Computer, error) + ListComputers() ([]*Computer, error) + GetComputerByID(id int) (*Computer, error) + GetComputerBySerial(serial string) (*Computer, error) + // AddSecret adds a new secret. Returns the secret, a bool indicating if it was newly created + // (false if the same secret value already exists), and any error. + AddSecret(computerID int, secretType, secret string, rotationRequired bool) (*Secret, bool, error) + ListSecretsByComputer(computerID int) ([]*Secret, error) + GetSecretByID(id int) (*Secret, error) + GetLatestSecretByComputerAndType(computerID int, secretType string) (*Secret, error) + AddRequest(secretID int, requestingUser, reason string, approvedBy string, approved *bool) (*Request, error) + ListRequestsBySecret(secretID int) ([]*Request, error) + ListOutstandingRequests() ([]*Request, error) + GetRequestByID(id int) (*Request, error) + ApproveRequest(requestID int, approved bool, reason, approver string) (*Request, error) + AddUser(username, passwordHash string, isStaff, canApprove, localLoginEnabled, mustResetPassword bool, authSource string) (*User, error) + GetUserByUsername(username string) (*User, error) + ListUsers() ([]*User, error) + GetUserByID(id int) (*User, error) + UpdateUser(id int, username string, isStaff, canApprove, localLoginEnabled, mustResetPassword bool, authSource string) (*User, error) + UpdateUserPassword(id int, passwordHash string, mustResetPassword bool) (*User, error) + DeleteUser(id int) error + CleanupRequests(approvedBefore time.Time) (int, error) + SetSecretRotationRequired(secretID int, rotationRequired bool) (*Secret, error) + AddAuditEvent(actor, targetUser, action, reason, ipAddress string) (*AuditEvent, error) + ListAuditEvents() ([]*AuditEvent, error) + SearchAuditEvents(query string) ([]*AuditEvent, error) + ListAuditEventsPaged(limit, offset int) ([]*AuditEvent, error) + SearchAuditEventsPaged(query string, limit, offset int) ([]*AuditEvent, error) + CountAuditEvents() (int, error) + CountSearchAuditEvents(query string) (int, error) + // IsEmpty returns true if all data tables are empty (no rows). + // This is used to check if it's safe to import fixture data. + IsEmpty() (bool, error) + // ImportComputer inserts a computer with a specific ID. + ImportComputer(id int, serial, username, computerName string, lastCheckin time.Time) error + // ImportSecret inserts a secret with a specific ID. The secret is already encrypted. + ImportSecret(id, computerID int, secretType, encryptedSecret string, dateEscrowed time.Time, rotationRequired bool) error + // ImportRequest inserts a request with a specific ID. + ImportRequest(id, secretID int, requestingUser string, approved *bool, authUser, reasonForRequest, reasonForApproval string, dateRequested time.Time, dateApproved *time.Time, current bool) error + // ImportUser inserts a user with a specific ID. + ImportUser(id int, username, passwordHash string, isStaff, canApprove, localLoginEnabled, mustResetPassword bool, authSource string) error +} diff --git a/internal/store/test_helpers_test.go b/internal/store/test_helpers_test.go new file mode 100644 index 0000000..7896f99 --- /dev/null +++ b/internal/store/test_helpers_test.go @@ -0,0 +1,21 @@ +package store + +import ( + "encoding/base64" + "testing" + + "crypt-server/internal/crypto" + "github.com/stretchr/testify/require" +) + +func testCodec(t *testing.T) *crypto.AesGcmCodec { + t.Helper() + key := make([]byte, 32) + for i := range key { + key[i] = byte(i + 1) + } + encoded := base64.StdEncoding.EncodeToString(key) + codec, err := crypto.NewAesGcmCodecFromBase64Key(encoded) + require.NoError(t, err) + return codec +} diff --git a/manage.py b/manage.py deleted file mode 100644 index b951459..0000000 --- a/manage.py +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env python -import os -import sys - -if __name__ == "__main__": - os.environ.setdefault("DJANGO_SETTINGS_MODULE", "fvserver.settings") - - from django.core.management import execute_from_command_line - - execute_from_command_line(sys.argv) diff --git a/remote_build.py b/remote_build.py deleted file mode 100644 index ae10abe..0000000 --- a/remote_build.py +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env python - -import subprocess -import requests -import os -import argparse - -parser = argparse.ArgumentParser(description="Process a build.") -parser.add_argument("build_tag", type=str, help="The tag to build.") - -args = parser.parse_args() - -api_user_token = os.getenv("CIRCLE_API_USER_TOKEN") -project_reponame = "crypt-server-saml" -project_username = "grahamgilbert" - -post_data = {} -post_data["build_parameters"] = {"TAG": args.build_tag} - -url = "https://circleci.com/api/v1.1/project/github/{}/{}/tree/master".format( - project_username, project_reponame -) - -the_request = requests.post(url, json=post_data, auth=(api_user_token, "")) -if the_request.status_code == requests.codes.ok: - print(the_request.json) -else: - print(the_request.text) diff --git a/server/__init__.py b/server/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/server/admin.py b/server/admin.py deleted file mode 100755 index 96e25ed..0000000 --- a/server/admin.py +++ /dev/null @@ -1,6 +0,0 @@ -from django.contrib import admin -from server.models import * - -admin.site.register(Computer) -admin.site.register(Secret) -admin.site.register(Request) diff --git a/server/forms.py b/server/forms.py deleted file mode 100755 index a08fcc4..0000000 --- a/server/forms.py +++ /dev/null @@ -1,35 +0,0 @@ -from django import forms -from .models import * - - -class RequestForm(forms.ModelForm): - class Meta: - model = Request - fields = ("reason_for_request",) - - -class ApproveForm(forms.ModelForm): - # approved = forms.BooleanField() - approved = forms.TypedChoiceField( - coerce=lambda x: bool(int(x)), - choices=((1, "Approved"), (0, "Denied")), - widget=forms.RadioSelect, - label="Approved?", - ) - - class Meta: - model = Request - fields = ("approved", "reason_for_approval") - - -class ComputerForm(forms.ModelForm): - class Meta: - model = Computer - fields = ("serial", "username", "computername") - - -class SecretForm(forms.ModelForm): - class Meta: - model = Secret - fields = ("secret_type", "secret", "computer") - widgets = {"computer": forms.HiddenInput()} diff --git a/server/migrations/0001_initial.py b/server/migrations/0001_initial.py deleted file mode 100644 index df51e30..0000000 --- a/server/migrations/0001_initial.py +++ /dev/null @@ -1,96 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - -from django.db import models, migrations -from django.conf import settings - - -class Migration(migrations.Migration): - - dependencies = [migrations.swappable_dependency(settings.AUTH_USER_MODEL)] - - operations = [ - migrations.CreateModel( - name="Computer", - fields=[ - ( - "id", - models.AutoField( - verbose_name="ID", - serialize=False, - auto_created=True, - primary_key=True, - ), - ), - ( - "recovery_key", - models.CharField(max_length=200, verbose_name=b"Recovery Key"), - ), - ( - "serial", - models.CharField(max_length=200, verbose_name=b"Serial Number"), - ), - ( - "username", - models.CharField(max_length=200, verbose_name=b"User Name"), - ), - ( - "computername", - models.CharField(max_length=200, verbose_name=b"Computer Name"), - ), - ("last_checkin", models.DateTimeField(null=True, blank=True)), - ], - options={ - "ordering": ["serial"], - "permissions": ( - ("can_approve", "Can approve requests to see encryption keys"), - ), - }, - ), - migrations.CreateModel( - name="Request", - fields=[ - ( - "id", - models.AutoField( - verbose_name="ID", - serialize=False, - auto_created=True, - primary_key=True, - ), - ), - ("approved", models.NullBooleanField(verbose_name=b"Approved?")), - ("reason_for_request", models.TextField()), - ( - "reason_for_approval", - models.TextField( - null=True, verbose_name=b"Approval Notes", blank=True - ), - ), - ("date_requested", models.DateTimeField(auto_now_add=True)), - ("date_approved", models.DateTimeField(null=True, blank=True)), - ("current", models.BooleanField(default=True)), - ( - "auth_user", - models.ForeignKey( - related_name="auth_user", - to=settings.AUTH_USER_MODEL, - null=True, - on_delete=models.CASCADE, - ), - ), - ( - "computer", - models.ForeignKey(to="server.Computer", on_delete=models.CASCADE), - ), - ( - "requesting_user", - models.ForeignKey( - related_name="requesting_user", - to=settings.AUTH_USER_MODEL, - on_delete=models.CASCADE, - ), - ), - ], - ), - ] diff --git a/server/migrations/0001_squashed_0017_merge_20181217_1829.py b/server/migrations/0001_squashed_0017_merge_20181217_1829.py deleted file mode 100644 index cd1addf..0000000 --- a/server/migrations/0001_squashed_0017_merge_20181217_1829.py +++ /dev/null @@ -1,333 +0,0 @@ -# Generated by Django 2.2.27 on 2022-03-28 14:50 - -from django.conf import settings -from django.db import migrations, models -import django.db.models.deletion -from django.shortcuts import get_object_or_404 - - -# Functions from the following migrations need manual copying. -# Move them and any dependencies into this file, then update the -# RunPython operations to refer to the local versions: -# server.migrations.0003_auto_20150713_1215 -# server.migrations.0007_auto_20150714_0822 -# server.migrations.0010_auto_20180726_1700 - - -def move_keys_and_requests(apps, schema_editor): - seen_serials = [("dummy_serial", "dummy_id")] - Computer = apps.get_model("server", "Computer") - Secret = apps.get_model("server", "Secret") - Request = apps.get_model("server", "Request") - for computer in Computer.objects.all(): - # if we've seen the serial before, get the computer that we saw before - target_id = None - for serial, id in seen_serials: - if computer.serial == serial: - target_id = id - break - if target_id == None: - target_id = computer.id - - target_computer = get_object_or_404(Computer, pk=target_id) - # create a new secret - secret = Secret( - computer=target_computer, - secret=computer.recovery_key, - date_escrowed=computer.last_checkin, - ) - secret.save() - - requests = Request.objects.filter(computer=computer) - for request in requests: - request.secret = secret - request.save() - - if target_computer.id != computer.id: - # Dupe computer, bin it - computer.delete() - - -class Migration(migrations.Migration): - - replaces = [ - ("server", "0001_initial"), - ("server", "0002_auto_20150713_1214"), - ("server", "0003_auto_20150713_1215"), - ("server", "0004_auto_20150713_1216"), - ("server", "0005_auto_20150713_1754"), - ("server", "0006_auto_20150714_0821"), - ("server", "0007_auto_20150714_0822"), - ("server", "0008_auto_20150814_2140"), - ("server", "0009_secret_rotation_required"), - ("server", "0010_auto_20180726_1700"), - ("server", "0011_manual_unique_serials"), - ("server", "0012_auto_20181128_2038"), - ("server", "0016_auto_20181213_2145"), - ("server", "0009_auto_20180430_2024"), - ("server", "0017_merge_20181217_1829"), - ] - - initial = True - - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ] - - operations = [ - migrations.CreateModel( - name="Computer", - fields=[ - ( - "id", - models.AutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ( - "recovery_key", - models.CharField(max_length=200, verbose_name=b"Recovery Key"), - ), - ( - "serial", - models.CharField(max_length=200, verbose_name=b"Serial Number"), - ), - ( - "username", - models.CharField(max_length=200, verbose_name=b"User Name"), - ), - ( - "computername", - models.CharField(max_length=200, verbose_name=b"Computer Name"), - ), - ("last_checkin", models.DateTimeField(blank=True, null=True)), - ], - options={ - "ordering": ["serial"], - "permissions": ( - ("can_approve", "Can approve requests to see encryption keys"), - ), - }, - ), - migrations.CreateModel( - name="Secret", - fields=[ - ( - "id", - models.AutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ("secret", models.CharField(max_length=256)), - ( - "secret_type", - models.CharField( - choices=[ - (b"recovery_key", b"Recovery Key"), - (b"password", b"Password"), - ], - default=b"recovery_key", - max_length=256, - ), - ), - ("date_escrowed", models.DateTimeField(auto_now_add=True)), - ( - "computer", - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - to="server.Computer", - ), - ), - ], - ), - migrations.CreateModel( - name="Request", - fields=[ - ( - "id", - models.AutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ("approved", models.NullBooleanField(verbose_name=b"Approved?")), - ("reason_for_request", models.TextField()), - ( - "reason_for_approval", - models.TextField( - blank=True, null=True, verbose_name=b"Approval Notes" - ), - ), - ("date_requested", models.DateTimeField(auto_now_add=True)), - ("date_approved", models.DateTimeField(blank=True, null=True)), - ("current", models.BooleanField(default=True)), - ( - "auth_user", - models.ForeignKey( - null=True, - on_delete=django.db.models.deletion.CASCADE, - related_name="auth_user", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "computer", - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - related_name="computers", - to="server.Computer", - ), - ), - ( - "requesting_user", - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - related_name="requesting_user", - to=settings.AUTH_USER_MODEL, - ), - ), - ( - "secret", - models.ForeignKey( - null=True, - on_delete=django.db.models.deletion.CASCADE, - related_name="secrets", - to="server.Secret", - ), - ), - ], - ), - migrations.RunPython( - code=move_keys_and_requests, - ), - migrations.RemoveField( - model_name="computer", - name="recovery_key", - ), - migrations.RemoveField( - model_name="request", - name="computer", - ), - migrations.AlterField( - model_name="request", - name="secret", - field=models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, to="server.Secret" - ), - ), - migrations.AlterModelOptions( - name="secret", - options={"ordering": ["-date_escrowed"]}, - ), - migrations.AddField( - model_name="secret", - name="rotation_required", - field=models.BooleanField(default=False), - ), - migrations.AlterField( - model_name="computer", - name="serial", - field=models.CharField( - max_length=200, unique=True, verbose_name=b"Serial Number" - ), - ), - migrations.AlterField( - model_name="computer", - name="computername", - field=models.CharField(max_length=200, verbose_name="Computer Name"), - ), - migrations.AlterField( - model_name="computer", - name="serial", - field=models.CharField( - max_length=200, unique=True, verbose_name="Serial Number" - ), - ), - migrations.AlterField( - model_name="computer", - name="username", - field=models.CharField(max_length=200, verbose_name="User Name"), - ), - migrations.AlterField( - model_name="request", - name="approved", - field=models.NullBooleanField(verbose_name="Approved?"), - ), - migrations.AlterField( - model_name="request", - name="reason_for_approval", - field=models.TextField( - blank=True, null=True, verbose_name="Approval Notes" - ), - ), - migrations.AlterField( - model_name="secret", - name="secret_type", - field=models.CharField( - choices=[("recovery_key", "Recovery Key"), ("password", "Password")], - default="recovery_key", - max_length=256, - ), - ), - migrations.AlterField( - model_name="request", - name="auth_user", - field=models.ForeignKey( - null=True, - on_delete=django.db.models.deletion.PROTECT, - related_name="auth_user", - to=settings.AUTH_USER_MODEL, - ), - ), - migrations.AlterField( - model_name="request", - name="secret", - field=models.ForeignKey( - on_delete=django.db.models.deletion.PROTECT, to="server.Secret" - ), - ), - migrations.AlterField( - model_name="computer", - name="computername", - field=models.CharField(max_length=200, verbose_name="Computer Name"), - ), - migrations.AlterField( - model_name="computer", - name="serial", - field=models.CharField(max_length=200, verbose_name="Serial Number"), - ), - migrations.AlterField( - model_name="computer", - name="username", - field=models.CharField(max_length=200, verbose_name="User Name"), - ), - migrations.AlterField( - model_name="request", - name="approved", - field=models.NullBooleanField(verbose_name="Approved?"), - ), - migrations.AlterField( - model_name="request", - name="reason_for_approval", - field=models.TextField( - blank=True, null=True, verbose_name="Approval Notes" - ), - ), - migrations.AlterField( - model_name="secret", - name="secret_type", - field=models.CharField( - choices=[("recovery_key", "Recovery Key"), ("password", "Password")], - default="recovery_key", - max_length=256, - ), - ), - ] diff --git a/server/migrations/0002_auto_20150713_1214.py b/server/migrations/0002_auto_20150713_1214.py deleted file mode 100644 index 9e9e22a..0000000 --- a/server/migrations/0002_auto_20150713_1214.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals -from django.shortcuts import get_object_or_404 -from django.db import models, migrations - - -class Migration(migrations.Migration): - - dependencies = [("server", "0001_initial")] - - operations = [ - migrations.CreateModel( - name="Secret", - fields=[ - ( - "id", - models.AutoField( - verbose_name="ID", - serialize=False, - auto_created=True, - primary_key=True, - ), - ), - ("secret", models.CharField(max_length=256)), - ( - "secret_type", - models.CharField( - default=b"recovery_key", - max_length=256, - choices=[ - (b"recovery_key", b"Recovery Key"), - (b"password", b"Password"), - ], - ), - ), - ("date_escrowed", models.DateTimeField(auto_now_add=True)), - ], - ), - migrations.AddField( - model_name="secret", - name="computer", - field=models.ForeignKey(to="server.Computer", on_delete=models.CASCADE), - ), - migrations.AlterField( - model_name="request", - name="computer", - field=models.ForeignKey( - related_name="computers", to="server.Computer", on_delete=models.CASCADE - ), - ), - migrations.AddField( - model_name="request", - name="secret", - field=models.ForeignKey( - null=True, - related_name="secrets", - to="server.Secret", - on_delete=models.CASCADE, - ), - preserve_default=False, - ), - ] diff --git a/server/migrations/0003_auto_20150713_1215.py b/server/migrations/0003_auto_20150713_1215.py deleted file mode 100644 index 2254ea2..0000000 --- a/server/migrations/0003_auto_20150713_1215.py +++ /dev/null @@ -1,45 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals -from django.shortcuts import get_object_or_404 -from django.db import models, migrations - - -def move_keys_and_requests(apps, schema_editor): - seen_serials = [("dummy_serial", "dummy_id")] - Computer = apps.get_model("server", "Computer") - Secret = apps.get_model("server", "Secret") - Request = apps.get_model("server", "Request") - for computer in Computer.objects.all(): - # if we've seen the serial before, get the computer that we saw before - target_id = None - for serial, id in seen_serials: - if computer.serial == serial: - target_id = id - break - if target_id == None: - target_id = computer.id - - target_computer = get_object_or_404(Computer, pk=target_id) - # create a new secret - secret = Secret( - computer=target_computer, - secret=computer.recovery_key, - date_escrowed=computer.last_checkin, - ) - secret.save() - - requests = Request.objects.filter(computer=computer) - for request in requests: - request.secret = secret - request.save() - - if target_computer.id != computer.id: - # Dupe computer, bin it - computer.delete() - - -class Migration(migrations.Migration): - - dependencies = [("server", "0002_auto_20150713_1214")] - - operations = [migrations.RunPython(move_keys_and_requests)] diff --git a/server/migrations/0004_auto_20150713_1216.py b/server/migrations/0004_auto_20150713_1216.py deleted file mode 100644 index 7c92513..0000000 --- a/server/migrations/0004_auto_20150713_1216.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals -from django.shortcuts import get_object_or_404 -from django.db import models, migrations - - -class Migration(migrations.Migration): - - dependencies = [("server", "0003_auto_20150713_1215")] - - operations = [ - migrations.RemoveField(model_name="computer", name="recovery_key"), - migrations.RemoveField(model_name="request", name="computer"), - ] diff --git a/server/migrations/0005_auto_20150713_1754.py b/server/migrations/0005_auto_20150713_1754.py deleted file mode 100644 index 17f5422..0000000 --- a/server/migrations/0005_auto_20150713_1754.py +++ /dev/null @@ -1,17 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - -from django.db import models, migrations - - -class Migration(migrations.Migration): - - dependencies = [("server", "0004_auto_20150713_1216")] - - operations = [ - migrations.AlterField( - model_name="request", - name="secret", - field=models.ForeignKey(to="server.Secret", on_delete=models.CASCADE), - ) - ] diff --git a/server/migrations/0006_auto_20150714_0821.py b/server/migrations/0006_auto_20150714_0821.py deleted file mode 100644 index bf51afa..0000000 --- a/server/migrations/0006_auto_20150714_0821.py +++ /dev/null @@ -1,21 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - -from django.db import models, migrations - -# import django_extensions.db.fields.encrypted - - -class Migration(migrations.Migration): - - dependencies = [("server", "0005_auto_20150713_1754")] - - operations = [ - # migrations.AlterField( - # model_name="secret", - # name="secret", - # field=django_extensions.db.fields.encrypted.EncryptedCharField( - # max_length=256 - # ), - # ) - ] diff --git a/server/migrations/0007_auto_20150714_0822.py b/server/migrations/0007_auto_20150714_0822.py deleted file mode 100644 index 690806b..0000000 --- a/server/migrations/0007_auto_20150714_0822.py +++ /dev/null @@ -1,19 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals -from django.shortcuts import get_object_or_404 -from django.db import models, migrations - - -def encrypt_secrets(apps, schema_editor): - - Secret = apps.get_model("server", "Secret") - - for secret in Secret.objects.all(): - secret.save() - - -class Migration(migrations.Migration): - - dependencies = [("server", "0006_auto_20150714_0821")] - - operations = [migrations.RunPython(encrypt_secrets)] diff --git a/server/migrations/0008_auto_20150814_2140.py b/server/migrations/0008_auto_20150814_2140.py deleted file mode 100644 index 2691a00..0000000 --- a/server/migrations/0008_auto_20150814_2140.py +++ /dev/null @@ -1,15 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - -from django.db import models, migrations - - -class Migration(migrations.Migration): - - dependencies = [("server", "0007_auto_20150714_0822")] - - operations = [ - migrations.AlterModelOptions( - name="secret", options={"ordering": ["-date_escrowed"]} - ) - ] diff --git a/server/migrations/0009_auto_20180430_2024.py b/server/migrations/0009_auto_20180430_2024.py deleted file mode 100644 index 5119fc4..0000000 --- a/server/migrations/0009_auto_20180430_2024.py +++ /dev/null @@ -1,47 +0,0 @@ -# Generated by Django 2.0.4 on 2018-04-30 19:24 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [("server", "0008_auto_20150814_2140")] - - operations = [ - migrations.AlterField( - model_name="computer", - name="computername", - field=models.CharField(max_length=200, verbose_name="Computer Name"), - ), - migrations.AlterField( - model_name="computer", - name="serial", - field=models.CharField(max_length=200, verbose_name="Serial Number"), - ), - migrations.AlterField( - model_name="computer", - name="username", - field=models.CharField(max_length=200, verbose_name="User Name"), - ), - migrations.AlterField( - model_name="request", - name="approved", - field=models.NullBooleanField(verbose_name="Approved?"), - ), - migrations.AlterField( - model_name="request", - name="reason_for_approval", - field=models.TextField( - blank=True, null=True, verbose_name="Approval Notes" - ), - ), - migrations.AlterField( - model_name="secret", - name="secret_type", - field=models.CharField( - choices=[("recovery_key", "Recovery Key"), ("password", "Password")], - default="recovery_key", - max_length=256, - ), - ), - ] diff --git a/server/migrations/0009_secret_rotation_required.py b/server/migrations/0009_secret_rotation_required.py deleted file mode 100644 index 408ddc6..0000000 --- a/server/migrations/0009_secret_rotation_required.py +++ /dev/null @@ -1,18 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by Django 1.10 on 2018-04-30 21:37 -from __future__ import unicode_literals - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [("server", "0008_auto_20150814_2140")] - - operations = [ - migrations.AddField( - model_name="secret", - name="rotation_required", - field=models.BooleanField(default=False), - ) - ] diff --git a/server/migrations/0010_auto_20180726_1700.py b/server/migrations/0010_auto_20180726_1700.py deleted file mode 100644 index 7a29c31..0000000 --- a/server/migrations/0010_auto_20180726_1700.py +++ /dev/null @@ -1,37 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by Django 1.10 on 2018-07-26 16:00 -from __future__ import unicode_literals - -from server.models import * -from django.db import migrations, models - - -def unique_serials(apps, schema_editor): - """ - Make sure serial numbers are unique - """ - seen_serials = [] - Computer = apps.get_model("server", "Computer") - Secret = apps.get_model("server", "Secret") - all_computers = Computer.objects.all() - for computer in all_computers: - if computer.serial not in seen_serials: - # not seen it before, add it to the list of devices we've seen - seen_serials.append(computer.serial) - else: - # we've seen it before, select all the secrets for the - # machine and move them to the first instance of the serial number - secrets = Secret.objects.filter(computer=computer) - # reselect here so we don't get bit when we delete the computer - first_computer = Computer.objects.all().first() - for secret in secrets: - secret.computer = first_computer - secret.save() - computer.delete() - - -class Migration(migrations.Migration): - - dependencies = [("server", "0009_secret_rotation_required")] - - operations = [migrations.RunPython(unique_serials)] diff --git a/server/migrations/0011_manual_unique_serials.py b/server/migrations/0011_manual_unique_serials.py deleted file mode 100644 index 6ab0ada..0000000 --- a/server/migrations/0011_manual_unique_serials.py +++ /dev/null @@ -1,21 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by Django 1.10 on 2018-07-26 16:00 -from __future__ import unicode_literals - -from server.models import * -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [("server", "0010_auto_20180726_1700")] - - operations = [ - migrations.AlterField( - model_name="computer", - name="serial", - field=models.CharField( - max_length=200, unique=True, verbose_name=b"Serial Number" - ), - ) - ] diff --git a/server/migrations/0012_auto_20181128_2038.py b/server/migrations/0012_auto_20181128_2038.py deleted file mode 100644 index a85c8dc..0000000 --- a/server/migrations/0012_auto_20181128_2038.py +++ /dev/null @@ -1,51 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by Django 1.10 on 2018-11-28 20:38 -from __future__ import unicode_literals - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [("server", "0011_manual_unique_serials")] - - operations = [ - migrations.AlterField( - model_name="computer", - name="computername", - field=models.CharField(max_length=200, verbose_name="Computer Name"), - ), - migrations.AlterField( - model_name="computer", - name="serial", - field=models.CharField( - max_length=200, unique=True, verbose_name="Serial Number" - ), - ), - migrations.AlterField( - model_name="computer", - name="username", - field=models.CharField(max_length=200, verbose_name="User Name"), - ), - migrations.AlterField( - model_name="request", - name="approved", - field=models.NullBooleanField(verbose_name="Approved?"), - ), - migrations.AlterField( - model_name="request", - name="reason_for_approval", - field=models.TextField( - blank=True, null=True, verbose_name="Approval Notes" - ), - ), - migrations.AlterField( - model_name="secret", - name="secret_type", - field=models.CharField( - choices=[("recovery_key", "Recovery Key"), ("password", "Password")], - default="recovery_key", - max_length=256, - ), - ), - ] diff --git a/server/migrations/0016_auto_20181213_2145.py b/server/migrations/0016_auto_20181213_2145.py deleted file mode 100644 index 3a5f9b9..0000000 --- a/server/migrations/0016_auto_20181213_2145.py +++ /dev/null @@ -1,30 +0,0 @@ -# Generated by Django 2.1.4 on 2018-12-13 21:45 - -from django.conf import settings -from django.db import migrations, models -import django.db.models.deletion - - -class Migration(migrations.Migration): - - dependencies = [("server", "0012_auto_20181128_2038")] - - operations = [ - migrations.AlterField( - model_name="request", - name="auth_user", - field=models.ForeignKey( - null=True, - on_delete=django.db.models.deletion.PROTECT, - related_name="auth_user", - to=settings.AUTH_USER_MODEL, - ), - ), - migrations.AlterField( - model_name="request", - name="secret", - field=models.ForeignKey( - on_delete=django.db.models.deletion.PROTECT, to="server.Secret" - ), - ), - ] diff --git a/server/migrations/0017_merge_20181217_1829.py b/server/migrations/0017_merge_20181217_1829.py deleted file mode 100644 index 371eec8..0000000 --- a/server/migrations/0017_merge_20181217_1829.py +++ /dev/null @@ -1,13 +0,0 @@ -# Generated by Django 2.1.4 on 2018-12-17 18:29 - -from django.db import migrations - - -class Migration(migrations.Migration): - - dependencies = [ - ("server", "0009_auto_20180430_2024"), - ("server", "0016_auto_20181213_2145"), - ] - - operations = [] diff --git a/server/migrations/0018_auto_20201029_2134.py b/server/migrations/0018_auto_20201029_2134.py deleted file mode 100644 index 9bd5f81..0000000 --- a/server/migrations/0018_auto_20201029_2134.py +++ /dev/null @@ -1,31 +0,0 @@ -# Generated by Django 2.2.13 on 2020-10-29 21:34 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [("server", "0017_merge_20181217_1829")] - - operations = [ - migrations.AlterField( - model_name="computer", - name="serial", - field=models.CharField( - max_length=200, unique=True, verbose_name="Serial Number" - ), - ), - migrations.AlterField( - model_name="secret", - name="secret_type", - field=models.CharField( - choices=[ - ("recovery_key", "Recovery Key"), - ("password", "Password"), - ("unlock_pin", "Unlock PIN"), - ], - default="recovery_key", - max_length=256, - ), - ), - ] diff --git a/server/migrations/0019_alter_request_approved_alter_secret_secret.py b/server/migrations/0019_alter_request_approved_alter_secret_secret.py deleted file mode 100644 index 02b36af..0000000 --- a/server/migrations/0019_alter_request_approved_alter_secret_secret.py +++ /dev/null @@ -1,24 +0,0 @@ -# Generated by Django 4.1.2 on 2022-11-10 17:28 - -from django.db import migrations, models -import encrypted_model_fields.fields - - -class Migration(migrations.Migration): - - dependencies = [ - ("server", "0018_auto_20201029_2134"), - ] - - operations = [ - migrations.AlterField( - model_name="request", - name="approved", - field=models.BooleanField(null=True, verbose_name="Approved?"), - ), - migrations.AlterField( - model_name="secret", - name="secret", - field=encrypted_model_fields.fields.EncryptedCharField(), - ), - ] diff --git a/server/migrations/__init__.py b/server/migrations/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/server/models.py b/server/models.py deleted file mode 100644 index aec40ea..0000000 --- a/server/models.py +++ /dev/null @@ -1,87 +0,0 @@ -from django.db import models -from django.contrib.auth.models import User, Permission -from django.contrib.contenttypes.models import ContentType -from encrypted_model_fields.fields import EncryptedCharField - -from django.core.exceptions import ValidationError - - -# Create your models here. -class Computer(models.Model): - # recovery_key = models.CharField(max_length=200, verbose_name="Recovery Key") - serial = models.CharField(max_length=200, verbose_name="Serial Number", unique=True) - username = models.CharField(max_length=200, verbose_name="User Name") - computername = models.CharField(max_length=200, verbose_name="Computer Name") - last_checkin = models.DateTimeField(blank=True, null=True) - - def __str__(self): - return self.computername - - class Meta: - ordering = ["serial"] - permissions = ( - ("can_approve", ("Can approve requests to see encryption keys")), - ) - - -SECRET_TYPES = ( - ("recovery_key", "Recovery Key"), - ("password", "Password"), - ("unlock_pin", "Unlock PIN"), -) - - -class Secret(models.Model): - computer = models.ForeignKey(Computer, on_delete=models.CASCADE) - secret = EncryptedCharField(max_length=256) - secret_type = models.CharField( - max_length=256, choices=SECRET_TYPES, default="recovery_key" - ) - date_escrowed = models.DateTimeField(auto_now_add=True) - rotation_required = models.BooleanField(default=False) - - def validate_unique(self, *args, **kwargs): - if ( - self.secret - in [ - str(s) - for s in self.__class__.objects.filter( - secret_type=self.secret_type, computer=self.computer - ) - ] - and not self.rotation_required - ): - raise ValidationError("already used") - super(Secret, self).validate_unique(*args, **kwargs) - - def save(self, *args, **kwargs): - self.validate_unique() - super(Secret, self).save(*args, **kwargs) - - def __str__(self): - return self.secret - - class Meta: - ordering = ["-date_escrowed"] - - -class Request(models.Model): - secret = models.ForeignKey(Secret, on_delete=models.PROTECT) - # computer = models.ForeignKey(Computer, null=True, related_name='computers') - requesting_user = models.ForeignKey( - User, related_name="requesting_user", on_delete=models.CASCADE - ) - approved = models.BooleanField(verbose_name="Approved?", null=True) - auth_user = models.ForeignKey( - User, null=True, related_name="auth_user", on_delete=models.PROTECT - ) - reason_for_request = models.TextField() - reason_for_approval = models.TextField( - blank=True, null=True, verbose_name="Approval Notes" - ) - date_requested = models.DateTimeField(auto_now_add=True) - date_approved = models.DateTimeField(blank=True, null=True) - current = models.BooleanField(default=True) - - def __str__(self): - return "%s - %s" % (self.secret, self.requesting_user) diff --git a/server/templates/server/approve.html b/server/templates/server/approve.html deleted file mode 100644 index 3111cd2..0000000 --- a/server/templates/server/approve.html +++ /dev/null @@ -1,16 +0,0 @@ -{% extends "base.html" %} - -{% load bootstrap4 %} -{% block content %} -

Approve Request

-

{{ the_request.secret.computer.computername }} ({{ the_request.secret.computer.serial }})

-{% if error_message %}

{{ error_message }}

{% endif %} -
{% csrf_token %} - - - {% bootstrap_form form %} - -

- -
-{% endblock %} diff --git a/server/templates/server/computer_info.html b/server/templates/server/computer_info.html deleted file mode 100755 index 91b70fa..0000000 --- a/server/templates/server/computer_info.html +++ /dev/null @@ -1,95 +0,0 @@ -{% extends "base.html" %} {% block script %} - - -{% endblock %} {% block dropdown %} - -{% endblock %} {% block nav %} - -{% endblock %} {% block content %} -
-
-

{{ computer.computername }}

-

{{ computer.serial }}

-
-
- -
-
- - - - - - - - - - - - - - - - - - - -
Username{{ computer.username }}
Computer Name{{ computer.computername }}
Serial Number{{ computer.serial }}
Last Checked In{{ computer.last_checkin}}
-
-
- {% block button %} {% endblock %} -
-
- -
-
-

Secrets

- - - - - - - - {% for secret in secrets %} - - - - - - - {% endfor %} - -
Secret TypeEscrow Date
{{ secret.get_secret_type_display }}{{ secret.date_escrowed }} - Info / Request -
-
-
- -{% endblock %} diff --git a/server/templates/server/index.html b/server/templates/server/index.html deleted file mode 100644 index 3e28878..0000000 --- a/server/templates/server/index.html +++ /dev/null @@ -1,80 +0,0 @@ -{% extends "base.html" %} - -{% block script %} - - -{% endblock %} - -{% block dropdown %} - -{% endblock %} -{% block nav %} -{% if perms.server.can_approve %} -{% if outstanding.count > 0 %} - -{% else %} - -{% endif %} -{% endif %} -{% endblock %} -{% block content %} -{% if perms.server.can_approve %} - {% if outstanding.count > 0 %} -
- You have outstanding requests to approve. -
- {% endif %} -{% endif %} - - - - - - - - - - - - - - - - {% for computer in computers.all %} - - {% endfor %} - -
Serial NumberComputer NameUser NameLast Checked In
{{ computer.serial }}{{ computer.computername }}{{ computer.username }}{{ computer.last_checkin }}Info
- -{% endblock %} diff --git a/server/templates/server/manage_requests.html b/server/templates/server/manage_requests.html deleted file mode 100644 index 87b29a6..0000000 --- a/server/templates/server/manage_requests.html +++ /dev/null @@ -1,42 +0,0 @@ -{% extends "base.html" %} - -{% block script %} - - -{% endblock %} - -{% block nav %} - -{% endblock %} -{% block content %} -

Key Requests

- - - - - - - - - - - - - {% for the_request in requests %} - - {% endfor %} - -
Serial NumberComputer NameRequested ByReason for RequestDate Requested
{{ the_request.secret.computer.serial }}{{ the_request.secret.computer.computername }}{{ the_request.requesting_user }}{{ the_request.reason_for_request }}{{ the_request.date_requested }}Manage
- -{% endblock %} diff --git a/server/templates/server/new_computer_form.html b/server/templates/server/new_computer_form.html deleted file mode 100644 index c4ea466..0000000 --- a/server/templates/server/new_computer_form.html +++ /dev/null @@ -1,18 +0,0 @@ -{% extends "base.html" %} -{% load bootstrap4 %} -{% block content %} -
-
-

New Computer

-{% if error_message %}

{{ error_message }}

{% endif %} -
{% csrf_token %} - - - {% bootstrap_form form %} - -

- -
-
-
-{% endblock %} diff --git a/server/templates/server/new_secret_form.html b/server/templates/server/new_secret_form.html deleted file mode 100644 index 180d423..0000000 --- a/server/templates/server/new_secret_form.html +++ /dev/null @@ -1,19 +0,0 @@ -{% extends "base.html" %} -{% load bootstrap4 %} -{% block content %} -
-
-

New Secret

-

Computer: {{computer}}

-{% if error_message %}

{{ error_message }}

{% endif %} -
{% csrf_token %} - - - {% bootstrap_form form %} - -

- -
-
-
-{% endblock %} diff --git a/server/templates/server/request.html b/server/templates/server/request.html deleted file mode 100755 index ad76cbe..0000000 --- a/server/templates/server/request.html +++ /dev/null @@ -1,18 +0,0 @@ -{% extends "base.html" %} -{% load bootstrap4 %} -{% block content %} -

Request Secret

-

{{ secret.computer.computername }} ({{ secret.computer.serial }})

-

{{ secret.get_secret_type_display }} - {{ secret.date_escrowed }}

-{% if error_message %}

{{ error_message }}

{% endif %} -
{% csrf_token %} - - - {% bootstrap_form form %} - {% if not perms.server.can_approve %} -

- {% else %} -

- {% endif %} -
-{% endblock %} diff --git a/server/templates/server/retrieve.html b/server/templates/server/retrieve.html deleted file mode 100755 index f305ac5..0000000 --- a/server/templates/server/retrieve.html +++ /dev/null @@ -1,50 +0,0 @@ -{% extends "base.html" %} - -{% block script %} - - -{% endblock %} - -{% block nav %} - -{% endblock %} -{% block content %} -
-

{{ computer.computername }}

-

{{ computer.serial }}

-
- -
-
-

{{ the_request.secret.get_secret_type_display }}:
- - {% spaceless %} - {% for char in the_request.secret.secret %} - {% if char in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" %} - {{ char }} - {% elif char in "0123456789" %} - {{ char }} - {% else %} - {{ char }} - {% endif %} - {% endfor %} - {% endspaceless %} - -

- -

This approval is valid for 7 days, after which you will need to submit another request for access.

-
-
- -{% endblock %} diff --git a/server/templates/server/secret_approved_button.html b/server/templates/server/secret_approved_button.html deleted file mode 100644 index 9244f42..0000000 --- a/server/templates/server/secret_approved_button.html +++ /dev/null @@ -1,7 +0,0 @@ -{% extends "server/secret_info.html" %} -{% block button %} - {% for the_request in approved|slice:":1" %} - Retrieve Key - - {% endfor %} -{% endblock %} diff --git a/server/templates/server/secret_info.html b/server/templates/server/secret_info.html deleted file mode 100644 index 95725ee..0000000 --- a/server/templates/server/secret_info.html +++ /dev/null @@ -1,81 +0,0 @@ -{% extends "base.html" %} - -{% block script %} - - -{% endblock %} - -{% block nav %} - -{% endblock %} -{% block content %} -
-
-

{{ computer.computername }}

-

{{ computer.serial }}

-
-
- -
-
- - - - - - - - -
Username{{ computer.username }}
Computer Name{{ computer.computername }}
Serial Number{{ computer.serial }}
Secret Type{{ secret.get_secret_type_display }}
Escrow Date{{ secret.date_escrowed }}
-
-
- {% block button %} - {% endblock %} -
-
- -{% if perms.server.can_approve %} -
-
-

Requests

- - - - - - - - - - - {% for the_request in requests %} - - - - - - - - - - - {% endfor %} - -
Requesting UserReason for RequestDate RequestedApproved ByApproval NotesDate Approved
{{ the_request.requesting_user }}{{ the_request.reason_for_request }}{{ the_request.date_requested }}{{ the_request.auth_user }}{{ the_request.reason_for_approval }}{{ the_request.date_approved }}
-
- -
-{% endif %} - -{% endblock %} diff --git a/server/templates/server/secret_request_button.html b/server/templates/server/secret_request_button.html deleted file mode 100644 index f5b4731..0000000 --- a/server/templates/server/secret_request_button.html +++ /dev/null @@ -1,12 +0,0 @@ -{% extends "server/secret_info.html" %} -{% block button %} - {% if not perms.server.can_approve %} - {% if can_request %} - Request Key - {% else %} - - {% endif %} - {% else %} - Get Key - {% endif %} -{% endblock %} diff --git a/server/tests.py b/server/tests.py deleted file mode 100644 index be7c506..0000000 --- a/server/tests.py +++ /dev/null @@ -1,30 +0,0 @@ -from django.test import TestCase, Client -from django.contrib.auth.models import User -from datetime import datetime -from server.models import Computer, Secret, Request - - -class RequestProcess(TestCase): - def test_request_passes_correct_data_to_template(self): - admin = User.objects.create_superuser("admin", "a@a.com", "sekrit") - tech = User.objects.create_user("tech", "a@a.com", "password") - tech.save() - tech_test_computer = Computer( - serial="TECHSERIAL", username="Daft Tech", computername="compy587" - ) - tech_test_computer.save() - test_secret = Secret( - computer=tech_test_computer, - secret="SHHH-DONT-TELL", - date_escrowed=datetime.now(), - ) - test_secret.save() - secret_request = Request(secret=test_secret, requesting_user=tech) - secret_request.save() - client = Client() - login_response = self.client.post( - "/login/", {"username": "admin", "password": "sekrit"}, follow=True - ) - response = self.client.get("/manage-requests/", follow=True) - print(response) - self.assertTrue(response.context["user"].is_authenticated) diff --git a/server/urls.py b/server/urls.py deleted file mode 100755 index 03ca584..0000000 --- a/server/urls.py +++ /dev/null @@ -1,31 +0,0 @@ -from django.urls import path -from . import views - -app_name = "server" - -urlpatterns = [ - # front. page - path("", views.index, name="home"), - path("ajax/", views.tableajax, name="tableajax"), - # Add computer - path("new/computer/", views.new_computer, name="new_computer"), - # Add secret - path("new/secret//", views.new_secret, name="new_secret"), - # secret info - path("info/secret//", views.secret_info, name="secret_info"), - # computerinfo - path("info//", views.computer_info, name="computer_info"), - path("info//", views.computer_info, name="computer_info_serial"), - # request - path("request//", views.request, name="request"), - # retrieve - path("retrieve//", views.retrieve, name="retrieve"), - # approve - path("approve//", views.approve, name="approve"), - # verify - path("verify///", views.verify, name="verify"), - # checkin - path("checkin/", views.checkin, name="checkin"), - # manage - path("manage-requests/", views.managerequests, name="managerequests"), -] diff --git a/server/views.py b/server/views.py deleted file mode 100644 index 8419e7a..0000000 --- a/server/views.py +++ /dev/null @@ -1,570 +0,0 @@ -import logging -from .models import * -from django.contrib.auth.decorators import login_required, permission_required -from django.template import RequestContext, Template, Context -import json -import pytz -import copy -from django.views.decorators.csrf import csrf_exempt, csrf_protect -from django.http import HttpResponse, Http404, JsonResponse -from django.contrib.auth.models import Permission, User -from django.conf import settings -from django.template.context_processors import csrf -from django.shortcuts import render, get_object_or_404, redirect -from datetime import datetime, timedelta -from django.db.models import Q -from .forms import * -from django.views.defaults import server_error -from django.core.mail import send_mail -from django.conf import settings -from django.urls import reverse -from django.utils.html import escape - -# Create your views here. -logger = logging.getLogger(__name__) - - -##clean up old requests -def cleanup(): - how_many_days = 7 - the_requests = Request.objects.filter( - date_approved__lte=datetime.now() - timedelta(days=how_many_days) - ).filter(current=True) - for the_req in the_requests: - the_req.current = False - the_req.save() - - -def get_server_version(): - current_dir = os.path.dirname(os.path.realpath(__file__)) - - with open( - os.path.join(os.path.dirname(current_dir), "fvserver", "version.plist"), "rb" - ) as f: - version = plistlib.load(f) - return version["version"] - - -##index view -@login_required -def index(request): - cleanup() - # show table with all the keys - computers = Computer.objects.none() - - if hasattr(settings, "ALL_APPROVE"): - if settings.ALL_APPROVE == True: - permissions = Permission.objects.all() - permission = Permission.objects.get(codename="can_approve") - if request.user.has_perm("server.can_approve") == False: - request.user.user_permissions.add(permission) - request.user.save() - ##get the number of oustanding requests - approved equals null - - outstanding = Request.objects.filter(approved__isnull=True) - if hasattr(settings, "APPROVE_OWN"): - if settings.APPROVE_OWN == False: - outstanding = outstanding.filter(~Q(requesting_user=request.user)) - c = {"user": request.user, "computers": computers, "outstanding": outstanding} - return render(request, "server/index.html", c) - - -@login_required -def tableajax(request): - """Table ajax for dataTables""" - # Pull our variables out of the GET request - get_data = request.GET["args"] - get_data = json.loads(get_data) - draw = get_data.get("draw", 0) - start = int(get_data.get("start", 0)) - length = int(get_data.get("length", 0)) - search_value = "" - if "search" in get_data: - if "value" in get_data["search"]: - search_value = get_data["search"]["value"] - - # default ordering - order_column = 2 - order_direction = "desc" - order_name = "" - if "order" in get_data: - order_column = get_data["order"][0]["column"] - order_direction = get_data["order"][0]["dir"] - for column in get_data.get("columns", None): - if column["data"] == order_column: - order_name = column["name"] - break - - machines = Computer.objects.all().values( - "id", "serial", "username", "computername", "last_checkin" - ) - - order_string = None - if len(order_name) != 0: - if order_direction == "desc": - order_string = "-%s" % order_name - else: - order_string = "%s" % order_name - - if len(search_value) != 0: - searched_machines = machines.filter( - Q(serial__icontains=search_value) - | Q(username__icontains=search_value) - | Q(computername__icontains=search_value) - | Q(last_checkin__icontains=search_value) - ) - - else: - searched_machines = machines - - if order_name != "info_button": - searched_machines = searched_machines.order_by(order_string) - - limited_machines = searched_machines[start : (start + length)] - - return_data = {} - return_data["draw"] = int(draw) - return_data["recordsTotal"] = machines.count() - return_data["recordsFiltered"] = return_data["recordsTotal"] - - return_data["data"] = [] - settings_time_zone = None - try: - settings_time_zone = pytz.timezone(settings.TIME_ZONE) - except Exception: - pass - - for machine in limited_machines: - if machine["last_checkin"]: - # formatted_date = pytz.utc.localize(machine.last_checkin) - if settings_time_zone: - formatted_date = ( - machine["last_checkin"] - .astimezone(settings_time_zone) - .strftime("%Y-%m-%d %H:%M %Z") - ) - else: - formatted_date = machine["last_checkin"].strftime("%Y-%m-%d %H:%M") - else: - formatted_date = "" - - serial_link = '%s' % ( - reverse("server:computer_info", args=[machine["id"]]), - escape(machine["serial"]), - ) - - computername_link = '%s' % ( - reverse("server:computer_info", args=[machine["id"]]), - escape(machine["computername"]), - ) - - info_button = 'Info' % ( - reverse("server:computer_info", args=[machine["id"]]) - ) - - list_data = [ - serial_link, - computername_link, - escape(machine["username"]), - formatted_date, - info_button, - ] - return_data["data"].append(list_data) - - return JsonResponse(return_data) - - -##view to see computer info -@login_required -def computer_info(request, computer_id=None): - cleanup() - try: - computer = get_object_or_404(Computer, pk=computer_id) - except: - computer = get_object_or_404(Computer, serial=computer_id) - can_request = None - approved = None - - # Get the secrets, annotated with whethere there are approvals for them - secrets = computer.secret_set.all().prefetch_related() - - for secret in secrets: - secret.approved = ( - Request.objects.filter(requesting_user=request.user) - .filter(approved=True) - .filter(current=True) - .filter(secret=secret) - ) - secret.pending = ( - Request.objects.filter(requesting_user=request.user) - .filter(approved__isnull=True) - .filter(secret=secret) - ) - - c = {"user": request.user, "computer": computer, "secrets": secrets} - - return render(request, "server/computer_info.html", c) - - -@login_required -def secret_info(request, secret_id): - cleanup() - - secret = get_object_or_404(Secret, pk=secret_id) - - computer = secret.computer - - ##check if the user has outstanding request for this - pending = secret.request_set.filter(requesting_user=request.user).filter( - approved__isnull=True - ) - if pending.count() == 0: - can_request = True - else: - can_request = False - ##if it's been approved, we'll show a link to retrieve the key - approved = ( - secret.request_set.filter(requesting_user=request.user) - .filter(approved=True) - .filter(current=True) - ) - requests = secret.request_set.all() - - c = { - "user": request.user, - "computer": computer, - "can_request": can_request, - "approved": approved, - "secret": secret, - "requests": requests, - } - if approved.count() != 0: - return render(request, "server/secret_approved_button.html", c) - else: - return render(request, "server/secret_request_button.html", c) - - -##request key view -@login_required -def request(request, secret_id): - ##we will auto approve this if the user has the right perms - secret = get_object_or_404(Secret, pk=secret_id) - approver = False - if request.user.has_perm("server.can_approve"): - approver = True - - if approver == True: - if hasattr(settings, "APPROVE_OWN"): - if settings.APPROVE_OWN == False: - approver = False - c = {} - c.update(csrf(request)) - if request.method == "POST": - form = RequestForm(request.POST) - if form.is_valid(): - new_request = form.save(commit=False) - new_request.requesting_user = request.user - new_request.secret = secret - new_request.save() - if approver: - new_request.auth_user = request.user - new_request.approved = True - new_request.date_approved = datetime.now() - new_request.save() - else: - # User isn't an approver, send an email to all of the approvers - perm = Permission.objects.get(codename="can_approve") - users = User.objects.filter( - Q(is_superuser=True) - | Q(groups__permissions=perm) - | Q(user_permissions=perm) - ).distinct() - - if hasattr(settings, "HOST_NAME"): - server_name = settings.HOST_NAME.rstrip("/") - else: - server_name = "http://crypt" - if hasattr(settings, "SEND_EMAIL"): - if settings.SEND_EMAIL == True: - for user in users: - if user.email: - email_message = """ There has been a new key request by %s. You can review this request at %s%s - """ % ( - request.user.username, - server_name, - reverse("server:approve", args=[new_request.id]), - ) - if hasattr(settings, "EMAIL_SENDER"): - email_sender = settings.EMAIL_SENDER - else: - email_sender = ( - "requests@%s" % request.META["SERVER_NAME"] - ) - - logger.info( - "[*] Sending request email to {} from {}".format( - user.email, email_sender - ) - ) - if settings.EMAIL_USER and settings.EMAIL_PASSWORD: - - authing_user = settings.EMAIL_USER - authing_password = settings.EMAIL_PASSWORD - logger.info( - "[*] Authing to mail server as {}".format( - authing_user - ) - ) - - send_mail( - "Crypt Key Request", - email_message, - email_sender, - [user.email], - fail_silently=True, - auth_user=authing_user, - auth_password=authing_password, - ) - else: - send_mail( - "Crypt Key Request", - email_message, - email_sender, - [user.email], - fail_silently=True, - ) - - ##if we're an approver, we'll redirect to the retrieve view - if approver: - return redirect("server:retrieve", new_request.id) - else: - return redirect("server:secret_info", secret.id) - else: - form = RequestForm() - c = {"form": form, "secret": secret} - return render(request, "server/request.html", c) - - -##retrieve key view -@login_required -def retrieve(request, request_id): - cleanup() - the_request = get_object_or_404(Request, pk=request_id) - if the_request.approved == True and the_request.current == True: - if hasattr(settings, "ROTATE_VIEWED_SECRETS"): - if settings.ROTATE_VIEWED_SECRETS: - the_request.secret.rotation_required = True - the_request.secret.save() - c = {"user": request.user, "the_request": the_request} - return render(request, "server/retrieve.html", c) - else: - raise Http404 - - -## approve key view -@permission_required("server.can_approve", login_url="/login/") -def approve(request, request_id): - the_request = get_object_or_404(Request, pk=request_id) - c = {} - c.update(csrf(request)) - if request.method == "POST": - form = ApproveForm(request.POST, instance=the_request) - if form.is_valid(): - new_request = form.save(commit=False) - new_request.auth_user = request.user - new_request.date_approved = datetime.now() - new_request.save() - - # Send an email to the requester with a link to retrieve (or not) - if hasattr(settings, "HOST_NAME"): - server_name = settings.HOST_NAME.rstrip("/") - else: - server_name = "http://crypt" - if new_request.approved == True: - request_status = "approved" - elif new_request.approved == False: - request_status = "denied" - if hasattr(settings, "SEND_EMAIL"): - if settings.SEND_EMAIL == True: - if new_request.requesting_user.email: - email_message = """ Your key request has been %s by %s. %s%s - """ % ( - request_status, - request.user.username, - server_name, - reverse("server:secret_info", args=[new_request.secret.id]), - ) - if hasattr(settings, "EMAIL_SENDER"): - email_sender = settings.EMAIL_SENDER - else: - email_sender = "requests@%s" % request.META["SERVER_NAME"] - - logger.info( - "[*] Sending approved/denied email to {} from {}".format( - new_request.requesting_user.email, email_sender - ) - ) - if settings.EMAIL_USER and settings.EMAIL_PASSWORD: - - authing_user = settings.EMAIL_USER - authing_password = settings.EMAIL_PASSWORD - logger.info( - "[*] Authing to mail server as {}".format(authing_user) - ) - - send_mail( - "Crypt Key Request", - email_message, - email_sender, - [new_request.requesting_user.email], - fail_silently=True, - auth_user=authing_user, - auth_password=authing_password, - ) - else: - send_mail( - "Crypt Key Request", - email_message, - email_sender, - [new_request.requesting_user.email], - fail_silently=True, - ) - return redirect("server:managerequests") - else: - form = ApproveForm(instance=the_request) - c = {"form": form, "user": request.user, "the_request": the_request} - return render(request, "server/approve.html", c) - - -##manage requests -@permission_required("server.can_approve", login_url="/login/") -def managerequests(request): - requests = Request.objects.filter(approved__isnull=True) - if hasattr(settings, "APPROVE_OWN"): - if settings.APPROVE_OWN == False: - requests = requests.filter(~Q(requesting_user=request.user)) - c = {"user": request.user, "requests": requests} - return render(request, "server/manage_requests.html", c) - - -# Add new manual computer -@login_required -def new_computer(request): - c = {} - c.update(csrf(request)) - if request.method == "POST": - form = ComputerForm(request.POST) - if form.is_valid(): - new_computer = form.save(commit=False) - new_computer.save() - form.save_m2m() - return redirect("server:computer_info", new_computer.id) - else: - form = ComputerForm() - c = {"form": form} - return render(request, "server/new_computer_form.html", c) - - -@login_required -def new_secret(request, computer_id): - c = {} - c.update(csrf(request)) - computer = get_object_or_404(Computer, pk=computer_id) - if request.method == "POST": - form_data = copy.copy(request.POST) - form_data["computer"] = computer.id - form = SecretForm(data=form_data) - if form.is_valid(): - new_secret = form.save(commit=False) - new_secret.computer = computer - new_secret.date_escrowed = datetime.now() - new_secret.save() - # form.save_m2m() - return redirect("server:computer_info", computer.id) - else: - form = SecretForm() - - c = {"form": form, "computer": computer} - return render(request, "server/new_secret_form.html", c) - - -# Verify key escrow -@csrf_exempt -def verify(request, serial, secret_type): - computer = get_object_or_404(Computer, serial=serial) - try: - secret = Secret.objects.filter( - computer=computer, secret_type=secret_type - ).latest("date_escrowed") - output = {"escrowed": True, "date_escrowed": secret.date_escrowed} - except Secret.DoesNotExist: - output = {"escrowed": False} - return JsonResponse(output) - - -##checkin view -@csrf_exempt -def checkin(request): - try: - serial_num = request.POST["serial"] - except: - return HttpResponse(status=500) - try: - recovery_pass = request.POST["recovery_password"] - except: - return HttpResponse(status=500) - - try: - macname = request.POST["macname"] - except: - macname = serial_num - - try: - user_name = request.POST["username"] - except: - return HttpResponse(status=500) - - try: - secret_type = request.POST["secret_type"] - except: - secret_type = "recovery_key" - - try: - computer = Computer.objects.get(serial=serial_num) - except Computer.DoesNotExist: - computer = Computer(serial=serial_num) - # computer = Computer(recovery_key=recovery_pass, serial=serial_num, last_checkin = datetime.now(), username=user_name, computername=macname) - computer.last_checkin = datetime.now() - computer.username = user_name - computer.computername = macname - computer.secret_type = secret_type - computer.save() - - new_secret_escrowed = True - - try: - secret = Secret( - computer=computer, - secret=recovery_pass, - secret_type=secret_type, - date_escrowed=datetime.now(), - ) - secret.save() - except ValidationError: - new_secret_escrowed = False - pass - - latest_secret = ( - Secret.objects.filter(secret_type=secret_type) - .filter(computer_id=computer.id) - .latest("date_escrowed") - ) - rotation_required = latest_secret.rotation_required - - c = { - "serial": computer.serial, - "username": computer.username, - "rotation_required": rotation_required, - "new_secret_escrowed": new_secret_escrowed, - } - return HttpResponse(json.dumps(c), content_type="application/json") diff --git a/set_build_no.py b/set_build_no.py deleted file mode 100755 index 498c2d2..0000000 --- a/set_build_no.py +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env python3 - -import os -import plistlib -import subprocess - -current_version = "3.4.1" -script_path = os.path.dirname(os.path.realpath(__file__)) - - -# based on http://tgoode.com/2014/06/05/sensible-way-increment-bundle-version-cfbundleversion-xcode - -print("Setting Version to Git rev-list --count") -cmd = ["git", "rev-list", "HEAD", "--count"] -build_number = subprocess.check_output(cmd) -# This will always be one commit behind, so this makes it current -build_number = int(build_number) + 1 - -version_number = "{}.{}".format(current_version, build_number) - -data = {"version": version_number} -plist_path = "{}/fvserver/version.plist".format(script_path) -file_name = open(plist_path, "wb") -plistlib.dump(data, file_name) -file_name.close() diff --git a/setup/requirements.txt b/setup/requirements.txt deleted file mode 100644 index c95a071..0000000 --- a/setup/requirements.txt +++ /dev/null @@ -1,44 +0,0 @@ -appdirs==1.4.4 -asgiref==3.5.2 -asn1crypto==1.5.1 -astroid==2.12.11 -attrs==22.1.0 -beautifulsoup4==4.11.1 -black==24.3.0 -cffi==1.15.0 -click==8.1.3 -cryptography==44.0.1 -Django==4.2.18 -django-bootstrap3==11.0.0 -django-bootstrap4==22.2 -django-debug-toolbar==3.7.0 -django-encrypted-model-fields==0.6.5 -django-extensions==3.2.1 -django-iam-dbauth==0.1.4 -docutils==0.19 -flake8==5.0.4 -gunicorn==22.0.0 -idna==3.7 -isort==5.10.1 -lazy-object-proxy==1.7.1 -mccabe==0.7.0 -meld3==2.0.1 -mypy-extensions==0.4.3 -pathspec==0.10.1 -psycopg2==2.9.4 -pyasn1==0.4.8 -pycodestyle==2.9.1 -pycparser==2.21 -pycrypto==2.6.1 -pyflakes==2.5.0 -pylint==2.15.4 -pytz==2022.5 -regex==2022.9.13 -selenium==4.15.1 -six==1.16.0 -soupsieve==2.3.2.post1 -sqlparse==0.5.0 -supervisor==4.2.4 -tomli==2.0.1 -whitenoise==6.2.0 -wrapt==1.14.1 \ No newline at end of file diff --git a/smtp.sh b/smtp.sh deleted file mode 100755 index a6be8ab..0000000 --- a/smtp.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -sudo python -m smtpd -n -c DebuggingServer localhost:25 diff --git a/static/.gitignore b/static/.gitignore deleted file mode 100644 index 5e7d273..0000000 --- a/static/.gitignore +++ /dev/null @@ -1,4 +0,0 @@ -# Ignore everything in this directory -* -# Except this file -!.gitignore diff --git a/templates/404.html b/templates/404.html deleted file mode 100755 index 63abf5b..0000000 --- a/templates/404.html +++ /dev/null @@ -1,9 +0,0 @@ -{% extends "base.html" %} - -{% block title %}Page not found{% endblock %} - -{% block content %} -

Page not found

- -

Sorry, but the requested page could not be found.

-{% endblock %} \ No newline at end of file diff --git a/templates/500.html b/templates/500.html deleted file mode 100755 index e9c2679..0000000 --- a/templates/500.html +++ /dev/null @@ -1,15 +0,0 @@ - - - - Page unavailable - - -

Page unavailable

- -

Sorry, but the requested page is unavailable due to a - server hiccup.

- -

Our engineers have been notified, so check back later.

- - \ No newline at end of file diff --git a/templates/admin/base_site.html b/templates/admin/base_site.html deleted file mode 100755 index 217ff2d..0000000 --- a/templates/admin/base_site.html +++ /dev/null @@ -1,23 +0,0 @@ -{% extends "admin/base.html" %} -{% load i18n %} - -{% block title %}Crypt{% endblock %} -{% block extrastyle %} - -{% endblock %} - -{% block branding %} -

Crypt

-{% endblock %} - -{% block nav-global %}{% endblock %} diff --git a/templates/base.html b/templates/base.html deleted file mode 100644 index 55bd4c2..0000000 --- a/templates/base.html +++ /dev/null @@ -1,93 +0,0 @@ -{% load i18n %} -{% load bootstrap4 %} - - - - - {% block title %}Crypt{% endblock %} - - - - - - - {# Load CSS and JavaScript #} - - -{% bootstrap_javascript jquery='full' %} - - - - - - - - - - - - - - -
- - - - {% block content %} - - {% endblock %} - -

Crypt Server version {{ CRYPT_VERSION }}

- -
- -{% block script %} -{% endblock%} - - diff --git a/templates/registration/login.html b/templates/registration/login.html deleted file mode 100755 index 040c2e7..0000000 --- a/templates/registration/login.html +++ /dev/null @@ -1,31 +0,0 @@ -{% extends "base.html" %} - -{% block content %} -{% if user.is_authenticated %} -

Welcome, {{ user.username }}.

-

-{% else %} - {% if form.errors %} -

Your username and password didn't match. Please try again.

- {% else %} -

Please log in.

- {% endif %} - -
- {% csrf_token %} - - - - - - - - - -
{{ form.username.label_tag }}{{ form.username }}
{{ form.password.label_tag }}{{ form.password }}
- - - -
-{% endif %} -{% endblock %} diff --git a/templates/registration/password_change_done.html b/templates/registration/password_change_done.html deleted file mode 100755 index 63b31c5..0000000 --- a/templates/registration/password_change_done.html +++ /dev/null @@ -1,7 +0,0 @@ -{% extends "base.html" %} -{% load bootstap_toolkit } -{% block content %} - -

Password changed!

-

Back to main page

-{% endblock %} diff --git a/templates/registration/password_change_form.html b/templates/registration/password_change_form.html deleted file mode 100755 index a52ba07..0000000 --- a/templates/registration/password_change_form.html +++ /dev/null @@ -1,29 +0,0 @@ -{% extends "base.html" %} -{% load bootstap_toolkit } -{% block content %} - -

Change password

- - -
- {% csrf_token %} - - - - - - - - - - - - - - - -
{{ form.old_password.label_tag }}{{ form.old_password }}
{{ form.new_password1.label_tag }}{{ form.new_password1 }}
{{ form.new_password2.label_tag }}{{ form.new_password2 }}
- - -
-{% endblock %} diff --git a/site_static/bootstrap/css/bootstrap-responsive.css b/web/static/bootstrap/css/bootstrap-responsive.css similarity index 100% rename from site_static/bootstrap/css/bootstrap-responsive.css rename to web/static/bootstrap/css/bootstrap-responsive.css diff --git a/site_static/bootstrap/css/bootstrap-responsive.min.css b/web/static/bootstrap/css/bootstrap-responsive.min.css similarity index 100% rename from site_static/bootstrap/css/bootstrap-responsive.min.css rename to web/static/bootstrap/css/bootstrap-responsive.min.css diff --git a/site_static/bootstrap/css/bootstrap.css b/web/static/bootstrap/css/bootstrap.css similarity index 100% rename from site_static/bootstrap/css/bootstrap.css rename to web/static/bootstrap/css/bootstrap.css diff --git a/site_static/bootstrap/css/bootstrap.min.css b/web/static/bootstrap/css/bootstrap.min.css similarity index 100% rename from site_static/bootstrap/css/bootstrap.min.css rename to web/static/bootstrap/css/bootstrap.min.css diff --git a/site_static/bootstrap/img/glyphicons-halflings-white.png b/web/static/bootstrap/img/glyphicons-halflings-white.png similarity index 100% rename from site_static/bootstrap/img/glyphicons-halflings-white.png rename to web/static/bootstrap/img/glyphicons-halflings-white.png diff --git a/site_static/bootstrap/img/glyphicons-halflings.png b/web/static/bootstrap/img/glyphicons-halflings.png similarity index 100% rename from site_static/bootstrap/img/glyphicons-halflings.png rename to web/static/bootstrap/img/glyphicons-halflings.png diff --git a/site_static/bootstrap/js/bootstrap.js b/web/static/bootstrap/js/bootstrap.js similarity index 100% rename from site_static/bootstrap/js/bootstrap.js rename to web/static/bootstrap/js/bootstrap.js diff --git a/site_static/bootstrap/js/bootstrap.min.js b/web/static/bootstrap/js/bootstrap.min.js similarity index 100% rename from site_static/bootstrap/js/bootstrap.min.js rename to web/static/bootstrap/js/bootstrap.min.js diff --git a/site_static/css/bootstrap.css b/web/static/css/bootstrap.css similarity index 100% rename from site_static/css/bootstrap.css rename to web/static/css/bootstrap.css diff --git a/site_static/css/bootstrap.min.css b/web/static/css/bootstrap.min.css similarity index 100% rename from site_static/css/bootstrap.min.css rename to web/static/css/bootstrap.min.css diff --git a/site_static/css/mixins.css b/web/static/css/mixins.css similarity index 100% rename from site_static/css/mixins.css rename to web/static/css/mixins.css diff --git a/site_static/css/styles.css b/web/static/css/styles.css similarity index 100% rename from site_static/css/styles.css rename to web/static/css/styles.css diff --git a/site_static/css/variables.css b/web/static/css/variables.css similarity index 100% rename from site_static/css/variables.css rename to web/static/css/variables.css diff --git a/site_static/dataTables/css/dataTables.bootstrap.css b/web/static/dataTables/css/dataTables.bootstrap.css similarity index 100% rename from site_static/dataTables/css/dataTables.bootstrap.css rename to web/static/dataTables/css/dataTables.bootstrap.css diff --git a/site_static/dataTables/css/dataTables.bootstrap.min.css b/web/static/dataTables/css/dataTables.bootstrap.min.css similarity index 100% rename from site_static/dataTables/css/dataTables.bootstrap.min.css rename to web/static/dataTables/css/dataTables.bootstrap.min.css diff --git a/site_static/dataTables/css/dataTables.bootstrap4.css b/web/static/dataTables/css/dataTables.bootstrap4.css similarity index 100% rename from site_static/dataTables/css/dataTables.bootstrap4.css rename to web/static/dataTables/css/dataTables.bootstrap4.css diff --git a/site_static/dataTables/css/dataTables.bootstrap4.min.css b/web/static/dataTables/css/dataTables.bootstrap4.min.css similarity index 100% rename from site_static/dataTables/css/dataTables.bootstrap4.min.css rename to web/static/dataTables/css/dataTables.bootstrap4.min.css diff --git a/site_static/dataTables/css/dataTables.foundation.css b/web/static/dataTables/css/dataTables.foundation.css similarity index 100% rename from site_static/dataTables/css/dataTables.foundation.css rename to web/static/dataTables/css/dataTables.foundation.css diff --git a/site_static/dataTables/css/dataTables.foundation.min.css b/web/static/dataTables/css/dataTables.foundation.min.css similarity index 100% rename from site_static/dataTables/css/dataTables.foundation.min.css rename to web/static/dataTables/css/dataTables.foundation.min.css diff --git a/site_static/dataTables/css/dataTables.jqueryui.css b/web/static/dataTables/css/dataTables.jqueryui.css similarity index 100% rename from site_static/dataTables/css/dataTables.jqueryui.css rename to web/static/dataTables/css/dataTables.jqueryui.css diff --git a/site_static/dataTables/css/dataTables.jqueryui.min.css b/web/static/dataTables/css/dataTables.jqueryui.min.css similarity index 100% rename from site_static/dataTables/css/dataTables.jqueryui.min.css rename to web/static/dataTables/css/dataTables.jqueryui.min.css diff --git a/site_static/dataTables/css/dataTables.semanticui.css b/web/static/dataTables/css/dataTables.semanticui.css similarity index 100% rename from site_static/dataTables/css/dataTables.semanticui.css rename to web/static/dataTables/css/dataTables.semanticui.css diff --git a/site_static/dataTables/css/dataTables.semanticui.min.css b/web/static/dataTables/css/dataTables.semanticui.min.css similarity index 100% rename from site_static/dataTables/css/dataTables.semanticui.min.css rename to web/static/dataTables/css/dataTables.semanticui.min.css diff --git a/site_static/dataTables/css/jquery.dataTables.css b/web/static/dataTables/css/jquery.dataTables.css similarity index 100% rename from site_static/dataTables/css/jquery.dataTables.css rename to web/static/dataTables/css/jquery.dataTables.css diff --git a/site_static/dataTables/css/jquery.dataTables.min.css b/web/static/dataTables/css/jquery.dataTables.min.css similarity index 100% rename from site_static/dataTables/css/jquery.dataTables.min.css rename to web/static/dataTables/css/jquery.dataTables.min.css diff --git a/site_static/dataTables/images/sort_asc.png b/web/static/dataTables/images/sort_asc.png similarity index 100% rename from site_static/dataTables/images/sort_asc.png rename to web/static/dataTables/images/sort_asc.png diff --git a/site_static/dataTables/images/sort_asc_disabled.png b/web/static/dataTables/images/sort_asc_disabled.png similarity index 100% rename from site_static/dataTables/images/sort_asc_disabled.png rename to web/static/dataTables/images/sort_asc_disabled.png diff --git a/site_static/dataTables/images/sort_both.png b/web/static/dataTables/images/sort_both.png similarity index 100% rename from site_static/dataTables/images/sort_both.png rename to web/static/dataTables/images/sort_both.png diff --git a/site_static/dataTables/images/sort_desc.png b/web/static/dataTables/images/sort_desc.png similarity index 100% rename from site_static/dataTables/images/sort_desc.png rename to web/static/dataTables/images/sort_desc.png diff --git a/site_static/dataTables/images/sort_desc_disabled.png b/web/static/dataTables/images/sort_desc_disabled.png similarity index 100% rename from site_static/dataTables/images/sort_desc_disabled.png rename to web/static/dataTables/images/sort_desc_disabled.png diff --git a/site_static/dataTables/js/dataTables.bootstrap.js b/web/static/dataTables/js/dataTables.bootstrap.js similarity index 100% rename from site_static/dataTables/js/dataTables.bootstrap.js rename to web/static/dataTables/js/dataTables.bootstrap.js diff --git a/site_static/dataTables/js/dataTables.bootstrap.min.js b/web/static/dataTables/js/dataTables.bootstrap.min.js similarity index 100% rename from site_static/dataTables/js/dataTables.bootstrap.min.js rename to web/static/dataTables/js/dataTables.bootstrap.min.js diff --git a/site_static/dataTables/js/dataTables.bootstrap4.js b/web/static/dataTables/js/dataTables.bootstrap4.js similarity index 100% rename from site_static/dataTables/js/dataTables.bootstrap4.js rename to web/static/dataTables/js/dataTables.bootstrap4.js diff --git a/site_static/dataTables/js/dataTables.bootstrap4.min.js b/web/static/dataTables/js/dataTables.bootstrap4.min.js similarity index 100% rename from site_static/dataTables/js/dataTables.bootstrap4.min.js rename to web/static/dataTables/js/dataTables.bootstrap4.min.js diff --git a/site_static/dataTables/js/dataTables.foundation.js b/web/static/dataTables/js/dataTables.foundation.js similarity index 100% rename from site_static/dataTables/js/dataTables.foundation.js rename to web/static/dataTables/js/dataTables.foundation.js diff --git a/site_static/dataTables/js/dataTables.foundation.min.js b/web/static/dataTables/js/dataTables.foundation.min.js similarity index 100% rename from site_static/dataTables/js/dataTables.foundation.min.js rename to web/static/dataTables/js/dataTables.foundation.min.js diff --git a/site_static/dataTables/js/dataTables.jqueryui.js b/web/static/dataTables/js/dataTables.jqueryui.js similarity index 100% rename from site_static/dataTables/js/dataTables.jqueryui.js rename to web/static/dataTables/js/dataTables.jqueryui.js diff --git a/site_static/dataTables/js/dataTables.jqueryui.min.js b/web/static/dataTables/js/dataTables.jqueryui.min.js similarity index 100% rename from site_static/dataTables/js/dataTables.jqueryui.min.js rename to web/static/dataTables/js/dataTables.jqueryui.min.js diff --git a/site_static/dataTables/js/dataTables.semanticui.js b/web/static/dataTables/js/dataTables.semanticui.js similarity index 100% rename from site_static/dataTables/js/dataTables.semanticui.js rename to web/static/dataTables/js/dataTables.semanticui.js diff --git a/site_static/dataTables/js/dataTables.semanticui.min.js b/web/static/dataTables/js/dataTables.semanticui.min.js similarity index 100% rename from site_static/dataTables/js/dataTables.semanticui.min.js rename to web/static/dataTables/js/dataTables.semanticui.min.js diff --git a/site_static/dataTables/js/jquery.dataTables.js b/web/static/dataTables/js/jquery.dataTables.js similarity index 100% rename from site_static/dataTables/js/jquery.dataTables.js rename to web/static/dataTables/js/jquery.dataTables.js diff --git a/site_static/dataTables/js/jquery.dataTables.min.js b/web/static/dataTables/js/jquery.dataTables.min.js similarity index 100% rename from site_static/dataTables/js/jquery.dataTables.min.js rename to web/static/dataTables/js/jquery.dataTables.min.js diff --git a/site_static/fonts/glyphicons-halflings-regular.eot b/web/static/fonts/glyphicons-halflings-regular.eot similarity index 100% rename from site_static/fonts/glyphicons-halflings-regular.eot rename to web/static/fonts/glyphicons-halflings-regular.eot diff --git a/site_static/fonts/glyphicons-halflings-regular.svg b/web/static/fonts/glyphicons-halflings-regular.svg similarity index 100% rename from site_static/fonts/glyphicons-halflings-regular.svg rename to web/static/fonts/glyphicons-halflings-regular.svg diff --git a/site_static/fonts/glyphicons-halflings-regular.ttf b/web/static/fonts/glyphicons-halflings-regular.ttf similarity index 100% rename from site_static/fonts/glyphicons-halflings-regular.ttf rename to web/static/fonts/glyphicons-halflings-regular.ttf diff --git a/site_static/fonts/glyphicons-halflings-regular.woff b/web/static/fonts/glyphicons-halflings-regular.woff similarity index 100% rename from site_static/fonts/glyphicons-halflings-regular.woff rename to web/static/fonts/glyphicons-halflings-regular.woff diff --git a/site_static/fonts/glyphicons-halflings-regular.woff2 b/web/static/fonts/glyphicons-halflings-regular.woff2 similarity index 100% rename from site_static/fonts/glyphicons-halflings-regular.woff2 rename to web/static/fonts/glyphicons-halflings-regular.woff2 diff --git a/site_static/img/atom-head-white.svg b/web/static/img/atom-head-white.svg similarity index 100% rename from site_static/img/atom-head-white.svg rename to web/static/img/atom-head-white.svg diff --git a/site_static/img/sal-logo-white.svg b/web/static/img/sal-logo-white.svg similarity index 100% rename from site_static/img/sal-logo-white.svg rename to web/static/img/sal-logo-white.svg diff --git a/site_static/img/select2-spinner.gif b/web/static/img/select2-spinner.gif similarity index 100% rename from site_static/img/select2-spinner.gif rename to web/static/img/select2-spinner.gif diff --git a/site_static/img/select2.png b/web/static/img/select2.png similarity index 100% rename from site_static/img/select2.png rename to web/static/img/select2.png diff --git a/site_static/img/select2x2.png b/web/static/img/select2x2.png similarity index 100% rename from site_static/img/select2x2.png rename to web/static/img/select2x2.png diff --git a/site_static/js/bootstrap.js b/web/static/js/bootstrap.js similarity index 100% rename from site_static/js/bootstrap.js rename to web/static/js/bootstrap.js diff --git a/site_static/js/bootstrap.min.js b/web/static/js/bootstrap.min.js similarity index 100% rename from site_static/js/bootstrap.min.js rename to web/static/js/bootstrap.min.js diff --git a/site_static/js/less-1.4.2.min.js b/web/static/js/less-1.4.2.min.js similarity index 100% rename from site_static/js/less-1.4.2.min.js rename to web/static/js/less-1.4.2.min.js diff --git a/site_static/js/main.js b/web/static/js/main.js similarity index 100% rename from site_static/js/main.js rename to web/static/js/main.js diff --git a/site_static/style.css b/web/static/style.css similarity index 100% rename from site_static/style.css rename to web/static/style.css diff --git a/web/templates/layouts/base.html b/web/templates/layouts/base.html new file mode 100644 index 0000000..2befa73 --- /dev/null +++ b/web/templates/layouts/base.html @@ -0,0 +1,66 @@ +{{define "base"}} + + + + + {{block "title" .}}Crypt{{end}} + + + + + + + + + + + + + +
+ + + {{block "content" .}}{{end}} + +

Crypt Server version {{.Version}}

+ +
+ +{{block "script" .}}{{end}} + + +{{end}} diff --git a/web/templates/pages/approve.html b/web/templates/pages/approve.html new file mode 100644 index 0000000..c178834 --- /dev/null +++ b/web/templates/pages/approve.html @@ -0,0 +1,22 @@ +{{define "title"}}Approve Request{{end}} + +{{define "content"}} +

Approve Request

+

{{.Computer.ComputerName}} ({{.Computer.Serial}})

+{{if .ErrorMessage}}

{{.ErrorMessage}}

{{end}} +
+ +
+ +
+ + +
+
+
+ + +
+

+
+{{end}} diff --git a/web/templates/pages/audit_log.html b/web/templates/pages/audit_log.html new file mode 100644 index 0000000..5bd7a18 --- /dev/null +++ b/web/templates/pages/audit_log.html @@ -0,0 +1,64 @@ +{{define "title"}}Audit Log{{end}} +{{define "content"}} +
+
+

Audit Log

+
+ + + {{if .AuditSearch}} + Clear + {{end}} +
+ + + + + + + + + + + + + {{if .AuditEvents}} + {{range .AuditEvents}} + + + + + + + + + {{end}} + {{else}} + + + + {{end}} + +
WhenActorTargetActionReasonIP
{{.CreatedAtFormatted}}{{.Actor}}{{.TargetUser}}{{.Action}}{{.Reason}}{{.IPAddress}}
No audit events yet.
+ {{if gt .AuditTotalPages 1}} + + {{end}} +
+
+{{end}} diff --git a/web/templates/pages/computer_info.html b/web/templates/pages/computer_info.html new file mode 100644 index 0000000..90bfb86 --- /dev/null +++ b/web/templates/pages/computer_info.html @@ -0,0 +1,86 @@ +{{define "title"}}Computer Info{{end}} + +{{define "script"}} + +{{end}} + +{{define "dropdown"}} + +{{end}} + +{{define "nav"}} + +{{end}} + +{{define "content"}} +
+
+

{{.Computer.ComputerName}}

+

{{.Computer.Serial}}

+
+
+ +
+
+ + + + + + + + + + + + + + + + + + + +
Username{{.Computer.Username}}
Computer Name{{.Computer.ComputerName}}
Serial Number{{.Computer.Serial}}
Last Checked In{{.Computer.LastCheckinFormatted}}
+
+
+ +
+
+

Secrets

+ + + + + + + + + + {{range .Secrets}} + + + + + + {{end}} + +
Secret TypeEscrow Date
{{.SecretTypeDisplay}}{{.DateEscrowedFormatted}} + Info / Request +
+
+
+{{end}} diff --git a/web/templates/pages/index.html b/web/templates/pages/index.html new file mode 100644 index 0000000..bd4dade --- /dev/null +++ b/web/templates/pages/index.html @@ -0,0 +1,80 @@ +{{define "title"}}Crypt{{end}} + +{{define "script"}} + +{{end}} + +{{define "dropdown"}} + +{{end}} + +{{define "nav"}} +{{if .User.CanApprove}} + {{if gt .OutstandingCount 0}} + + {{else}} + + {{end}} +{{end}} +{{end}} + +{{define "content"}} +{{if .User.CanApprove}} + {{if gt .OutstandingCount 0}} +
+ You have outstanding requests to approve. +
+ {{end}} +{{end}} + + + + + + + + + + + + + {{range .Computers}} + + + + + + + + {{end}} + +
Serial NumberComputer NameUser NameLast Checked In
{{.Serial}}{{.ComputerName}}{{.Username}}{{.LastCheckin}}Info
+{{end}} diff --git a/web/templates/pages/login.html b/web/templates/pages/login.html new file mode 100644 index 0000000..2caa634 --- /dev/null +++ b/web/templates/pages/login.html @@ -0,0 +1,33 @@ +{{define "title"}}Login{{end}} + +{{define "content"}} +{{if .User.IsAuthenticated}} +

Welcome, {{.User.Username}}.

+

+{{else}} +

Please log in.

+ {{if .SAMLAvailable}} +

Sign in with SAML

+
+ {{end}} +

Local login

+ {{if .ErrorMessage}} +
{{.ErrorMessage}}
+ {{end}} +
+ + + + + + + + + + +
+ + +
+{{end}} +{{end}} diff --git a/web/templates/pages/manage_requests.html b/web/templates/pages/manage_requests.html new file mode 100644 index 0000000..6ac803c --- /dev/null +++ b/web/templates/pages/manage_requests.html @@ -0,0 +1,44 @@ +{{define "title"}}Manage Requests{{end}} + +{{define "script"}} + +{{end}} + +{{define "nav"}} + +{{end}} + +{{define "content"}} +

Key Requests

+ + + + + + + + + + + + + {{range .ManageRequests}} + + + + + + + + + {{end}} + +
Serial NumberComputer NameRequested ByReason for RequestDate Requested
{{.Serial}}{{.ComputerName}}{{.RequestingUser}}{{.ReasonForRequest}}{{.DateRequested}}Manage
+ +{{end}} diff --git a/web/templates/pages/new_computer.html b/web/templates/pages/new_computer.html new file mode 100644 index 0000000..b5bb999 --- /dev/null +++ b/web/templates/pages/new_computer.html @@ -0,0 +1,26 @@ +{{define "title"}}New Computer{{end}} + +{{define "content"}} +
+
+

New Computer

+ {{if .ErrorMessage}}

{{.ErrorMessage}}

{{end}} +
+ +
+ + +
+
+ + +
+
+ + +
+

+
+
+
+{{end}} diff --git a/web/templates/pages/new_secret.html b/web/templates/pages/new_secret.html new file mode 100644 index 0000000..e6ac105 --- /dev/null +++ b/web/templates/pages/new_secret.html @@ -0,0 +1,31 @@ +{{define "title"}}New Secret{{end}} + +{{define "content"}} +
+
+

New Secret

+

Computer: {{.Computer.ComputerName}}

+ {{if .ErrorMessage}}

{{.ErrorMessage}}

{{end}} +
+ +
+ + +
+
+ + +
+
+ + +
+

+
+
+
+{{end}} diff --git a/web/templates/pages/password_change.html b/web/templates/pages/password_change.html new file mode 100644 index 0000000..40d9bb7 --- /dev/null +++ b/web/templates/pages/password_change.html @@ -0,0 +1,26 @@ +{{define "title"}}Change Password{{end}} + +{{define "content"}} +
+
+

Change Password

+ {{if .ErrorMessage}} +
{{.ErrorMessage}}
+ {{end}} +
+ + {{if .PasswordChangeRequiresCurrent}} +
+ + +
+ {{end}} +
+ + +
+ +
+
+
+{{end}} diff --git a/web/templates/pages/request.html b/web/templates/pages/request.html new file mode 100644 index 0000000..356d3ab --- /dev/null +++ b/web/templates/pages/request.html @@ -0,0 +1,20 @@ +{{define "title"}}Request Secret{{end}} + +{{define "content"}} +

Request Secret

+

{{.Computer.ComputerName}} ({{.Computer.Serial}})

+

{{.Secret.SecretTypeDisplay}} - {{.Secret.DateEscrowedFormatted}}

+{{if .ErrorMessage}}

{{.ErrorMessage}}

{{end}} +
+ +
+ + +
+ {{if .User.CanApprove}} +

+ {{else}} +

+ {{end}} +
+{{end}} diff --git a/web/templates/pages/retrieve.html b/web/templates/pages/retrieve.html new file mode 100644 index 0000000..f1b4395 --- /dev/null +++ b/web/templates/pages/retrieve.html @@ -0,0 +1,34 @@ +{{define "title"}}Retrieve Secret{{end}} + +{{define "script"}} + +{{end}} + +{{define "nav"}} + +{{end}} + +{{define "content"}} +
+
+

{{.Computer.ComputerName}}

+

{{.Computer.Serial}}

+
+
+ +
+
+

{{.Secret.SecretTypeDisplay}}:
+ {{range .SecretChars}}{{.Char}}{{end}} +

+ +

This approval is valid for 7 days, after which you will need to submit another request for access.

+
+
+{{end}} diff --git a/web/templates/pages/secret_info.html b/web/templates/pages/secret_info.html new file mode 100644 index 0000000..bae5e4d --- /dev/null +++ b/web/templates/pages/secret_info.html @@ -0,0 +1,83 @@ +{{define "title"}}Secret Info{{end}} + +{{define "script"}} + +{{end}} + +{{define "nav"}} + +{{end}} + +{{define "content"}} +
+
+

{{.Computer.ComputerName}}

+

{{.Computer.Serial}}

+
+
+ +
+
+ + + + + + + + +
Username{{.Computer.Username}}
Computer Name{{.Computer.ComputerName}}
Serial Number{{.Computer.Serial}}
Secret Type{{.Secret.SecretTypeDisplay}}
Escrow Date{{.Secret.DateEscrowedFormatted}}
+
+
+ {{if .User.CanApprove}} + {{if .RequestApproved}} + Retrieve Key + {{else}} + Get Key + {{end}} + {{else}} + {{if .CanRequest}} + Request Key + {{else}} + + {{end}} + {{end}} +
+
+ +{{if .User.CanApprove}} +
+
+

Requests

+ + + + + + + + + + + {{range .RequestsForSecret}} + + + + + + + + + {{end}} + +
Requesting UserReason for RequestDate RequestedApproved ByApproval NotesDate Approved
{{.RequestingUser}}{{.ReasonForRequest}}{{.DateRequestedFormatted}}{{.AuthUser}}{{.ReasonForApproval}}{{.DateApprovedFormatted}}
+
+
+{{end}} +{{end}} diff --git a/web/templates/pages/user_delete.html b/web/templates/pages/user_delete.html new file mode 100644 index 0000000..1d81256 --- /dev/null +++ b/web/templates/pages/user_delete.html @@ -0,0 +1,18 @@ +{{define "title"}}Delete User{{end}} + +{{define "content"}} +
+
+

Delete User

+ {{if .ErrorMessage}} +
{{.ErrorMessage}}
+ {{end}} +

Are you sure you want to delete {{.AdminUser.Username}}?

+
+ + + Cancel +
+
+
+{{end}} diff --git a/web/templates/pages/user_edit.html b/web/templates/pages/user_edit.html new file mode 100644 index 0000000..dc80260 --- /dev/null +++ b/web/templates/pages/user_edit.html @@ -0,0 +1,45 @@ +{{define "title"}}Edit User{{end}} + +{{define "content"}} +
+
+

Edit User

+ {{if .ErrorMessage}} +
{{.ErrorMessage}}
+ {{end}} +
+ +
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+ + Back +
+
+
+{{end}} diff --git a/web/templates/pages/user_list.html b/web/templates/pages/user_list.html new file mode 100644 index 0000000..c6e8846 --- /dev/null +++ b/web/templates/pages/user_list.html @@ -0,0 +1,42 @@ +{{define "title"}}Users{{end}} + +{{define "content"}} +
+
+
+

Users

+ New User +
+ + + + + + + + + + + + + + {{range .Users}} + + + + + + + + + + {{end}} + +
UsernameAdminCan ApproveLocal LoginMust ResetAuth Source
{{.Username}}{{if .IsStaff}}Yes{{else}}No{{end}}{{if .CanApprove}}Yes{{else}}No{{end}}{{if .LocalLoginEnabled}}Yes{{else}}No{{end}}{{if .MustResetPassword}}Yes{{else}}No{{end}}{{.AuthSource}} + Edit + Reset Password + Delete +
+
+
+{{end}} diff --git a/web/templates/pages/user_new.html b/web/templates/pages/user_new.html new file mode 100644 index 0000000..37c6b84 --- /dev/null +++ b/web/templates/pages/user_new.html @@ -0,0 +1,48 @@ +{{define "title"}}New User{{end}} + +{{define "content"}} +
+
+

New User

+ {{if .ErrorMessage}} +
{{.ErrorMessage}}
+ {{end}} +
+ +
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+ +
+
+
+{{end}} diff --git a/web/templates/pages/user_password.html b/web/templates/pages/user_password.html new file mode 100644 index 0000000..a18be23 --- /dev/null +++ b/web/templates/pages/user_password.html @@ -0,0 +1,22 @@ +{{define "title"}}Reset Password{{end}} + +{{define "content"}} +
+
+

Reset Password

+

User: {{.AdminUser.Username}}

+ {{if .ErrorMessage}} +
{{.ErrorMessage}}
+ {{end}} +
+ +
+ + +
+ + Back +
+
+
+{{end}}