commit ba89b18bce2b1a13287eaac8f66ef6d73c8a8070
Author: hailin
Date:   Sun May 18 22:23:26 2025 +0800

    first commit, based on chatbot-ui and self-hosted supabase.

diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..07f8151
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,427 @@
+# syntax=docker/dockerfile:1.6
+#============================================ storage-api ======================================================
+# Base stage for shared environment setup
+FROM ubuntu:22.04 as s3base
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Install Node.js 18 + build dependencies + xattr
+RUN apt-get update && apt-get install -y \
+    curl ca-certificates gnupg lsb-release \
+    g++ make python3 libattr1 \
+    && curl -fsSL https://deb.nodesource.com/setup_18.x | bash - \
+    && apt-get install -y nodejs \
+    && npm install -g npm@10.8.2 pnpm@10.9.0
+
+WORKDIR /app
+COPY storage_v1.19.1/package.json storage_v1.19.1/package-lock.json ./
+
+# Dependencies stage - install and cache all dependencies
+FROM s3base as dependencies
+RUN npm ci
+# Cache the installed node_modules for later stages
+RUN cp -R node_modules /node_modules_cache
+
+# Build stage - use cached node_modules for building the application
+FROM s3base as s3build
+COPY --from=dependencies /node_modules_cache ./node_modules
+COPY storage_v1.19.1/. .
+RUN npm run build
+
+# Production dependencies stage - use npm cache to install only production dependencies
+FROM s3base as production-deps
+COPY --from=dependencies /node_modules_cache ./node_modules
+RUN npm ci --production
+
+# Final stage - for the production build
+FROM s3base as s3final
+# ARG VERSION
+# ENV VERSION=$VERSION
+COPY storage_v1.19.1/migrations /migrations
+
+# Copy production node_modules from the production dependencies stage
+COPY --from=production-deps /app/node_modules /node_modules
+# Copy build artifacts from the build stage
+COPY --from=s3build /app/dist /dist
+
+#EXPOSE 5000
+#CMD ["node", "dist/start/server.js"]
+
+#============================================ chatdesk-ui ========================================================
+FROM nvcr.io/nvidia/tritonserver:24.04-py3-min AS chataibuilder
+
+RUN curl -fsSL https://deb.nodesource.com/setup_18.x | bash - && \
+    apt-get update && \
+    apt-get install -y nodejs && \
+    npm install -g npm@10.8.2 pnpm@10.9.0
+
+WORKDIR /app
+
+# Copy the dependency manifests and install dependencies
+COPY chatdesk-ui/package.json chatdesk-ui/package-lock.json ./
+RUN npm ci
+
+# Copy all source code
+COPY chatdesk-ui/. .
+
+# Build the project
+RUN npm run build
+
+
+#============================================ gotrue build======================================================
+FROM golang:1.22.3-alpine3.20 as authbuild
+
+RUN apk add --no-cache make git
+
+WORKDIR /go/src/github.com/supabase/auth
+
+# Pulling dependencies
+COPY auth_v2.169.0/Makefile auth_v2.169.0/go.* ./
+RUN make deps
+
+# Building stuff
+COPY auth_v2.169.0/. /go/src/github.com/supabase/auth
+
+# Make sure you change the RELEASE_VERSION value before publishing an image.
+RUN GO111MODULE=on CGO_ENABLED=0 GOOS=linux RELEASE_VERSION=2.169.0 make build
+
+
+#============================================ postgres ====================================================
+FROM nvcr.io/nvidia/tritonserver:24.04-py3-min as base
+
+
+
+#=========================================== kong =========================================================
+ARG ASSET=ce
+ENV ASSET $ASSET
+
+ARG EE_PORTS
+
+COPY docker-kong_v2.8.1/ubuntu/kong.deb /tmp/kong.deb
+
+ARG KONG_VERSION=2.8.1
+ENV KONG_VERSION $KONG_VERSION
+
+ARG KONG_AMD64_SHA="10d12d23e5890414d666663094d51a42de41f8a9806fbc0baaf9ac4d37794361"
+ARG KONG_ARM64_SHA="61c13219ef64dac9aeae5ae775411e8cfcd406f068cf3e75d463f916ae6513cb"
+
+# hadolint ignore=DL3015
+RUN set -ex; \
+    arch=$(dpkg --print-architecture); \
+    case "${arch}" in \
+      amd64) KONG_SHA256=$KONG_AMD64_SHA ;; \
+      arm64) KONG_SHA256=$KONG_ARM64_SHA ;; \
+    esac; \
+    apt-get update \
+    && if [ "$ASSET" = "ce" ] ; then \
+      apt-get install -y curl \
+      && UBUNTU_CODENAME=focal \
+      && KONG_REPO=$(echo ${KONG_VERSION%.*} | sed 's/\.//') \
+      && curl -fL https://packages.konghq.com/public/gateway-$KONG_REPO/deb/ubuntu/pool/$UBUNTU_CODENAME/main/k/ko/kong_$KONG_VERSION/kong_${KONG_VERSION}_$arch.deb -o /tmp/kong.deb \
+      && apt-get purge -y curl \
+      && echo "$KONG_SHA256 /tmp/kong.deb" | sha256sum -c -; \
+    else \
+      # this needs to stay inside this "else" block so that it does not become part of the "official images" builds (https://github.com/docker-library/official-images/pull/11532#issuecomment-996219700)
+      apt-get upgrade -y ; \
+    fi; \
+    apt-get install -y --no-install-recommends unzip git \
+    # Please update the ubuntu install docs if the below line is changed so that
+    # end users can properly install Kong along with its required dependencies
+    # and that our CI does not diverge from our docs.
+    && apt install --yes /tmp/kong.deb \
+    && rm -rf /var/lib/apt/lists/* \
+    && rm -rf /tmp/kong.deb \
+    && chown kong:0 /usr/local/bin/kong \
+    && chown -R kong:0 /usr/local/kong \
+    && ln -s /usr/local/openresty/bin/resty /usr/local/bin/resty \
+    && ln -s /usr/local/openresty/luajit/bin/luajit /usr/local/bin/luajit \
+    && ln -s /usr/local/openresty/luajit/bin/luajit /usr/local/bin/lua \
+    && ln -s /usr/local/openresty/nginx/sbin/nginx /usr/local/bin/nginx \
+    && if [ "$ASSET" = "ce" ] ; then \
+      kong version ; \
+    fi
+
+COPY --chmod=0755 docker-kong_v2.8.1/ubuntu/docker-entrypoint.sh /supabase/kong/docker-entrypoint.sh
+
+
+
+ARG postgresql_major=15
+ARG postgresql_release=${postgresql_major}.1
+
+# Bump default build arg to build a package from source
+# Bump vars.yml to specify runtime package version
+ARG sfcgal_release=1.3.10
+ARG postgis_release=3.3.2
+ARG pgrouting_release=3.4.1
+ARG pgtap_release=1.2.0
+ARG pg_cron_release=1.6.2
+ARG pgaudit_release=1.7.0
+ARG pgjwt_release=9742dab1b2f297ad3811120db7b21451bca2d3c9
+ARG pgsql_http_release=1.5.0
+ARG plpgsql_check_release=2.2.5
+ARG pg_safeupdate_release=1.4
+ARG timescaledb_release=2.9.1
+ARG wal2json_release=2_5
+ARG pljava_release=1.6.4
+ARG plv8_release=3.1.5
+ARG pg_plan_filter_release=5081a7b5cb890876e67d8e7486b6a64c38c9a492
+ARG pg_net_release=0.7.1
+ARG rum_release=1.3.13
+ARG pg_hashids_release=cd0e1b31d52b394a0df64079406a14a4f7387cd6
+ARG libsodium_release=1.0.18
+ARG pgsodium_release=3.1.6
+ARG pg_graphql_release=1.5.11
+ARG pg_stat_monitor_release=1.1.1
+ARG pg_jsonschema_release=0.1.4
+ARG pg_repack_release=1.4.8
+ARG vault_release=0.2.8
+ARG groonga_release=12.0.8
+ARG pgroonga_release=2.4.0
+ARG wrappers_release=0.3.0
+ARG hypopg_release=1.3.1
+ARG pgvector_release=0.4.0
+ARG pg_tle_release=1.3.2
+ARG index_advisor_release=0.2.0
+ARG supautils_release=2.2.0
+ARG wal_g_release=2.0.1
+
+#FROM nvcr.io/nvidia/tritonserver:24.04-py3-min as base
+
+RUN apt update -y && apt install -y \
+    curl \
+    gnupg \
+    lsb-release \
+    software-properties-common \
+    wget \
+    sudo \
+    && apt clean
+
+
+RUN adduser --system --home /var/lib/postgresql --no-create-home --shell /bin/bash --group --gecos "PostgreSQL administrator" postgres
+RUN adduser --system --no-create-home --shell /bin/bash --group wal-g
+RUN curl --proto '=https' --tlsv1.2 -sSf -L https://install.determinate.systems/nix | sh -s -- install linux \
+--init none \
+--no-confirm \
+--extra-conf "substituters = https://cache.nixos.org https://nix-postgres-artifacts.s3.amazonaws.com" \
+--extra-conf "trusted-public-keys = nix-postgres-artifacts:dGZlQOvKcNEjvT7QEAJbcV6b6uk7VF/hWMjhYleiaLI=% cache.nixos.org-1:6NCHdD59X431o0gWypbMrAURkbJ16ZPMQFGspcDShjY="
+
+ENV PATH="${PATH}:/nix/var/nix/profiles/default/bin"
+
+COPY postgres_15.8.1.044/. /nixpg
+
+WORKDIR /nixpg
+
+RUN nix profile install .#psql_15/bin
+
+
+
+WORKDIR /
+
+
+RUN mkdir -p /usr/lib/postgresql/bin \
+    /usr/lib/postgresql/share/postgresql \
+    /usr/share/postgresql \
+    /var/lib/postgresql \
+    && chown -R postgres:postgres /usr/lib/postgresql \
+    && chown -R postgres:postgres /var/lib/postgresql \
+    && chown -R postgres:postgres /usr/share/postgresql
+
+# Create symbolic links
+RUN ln -s /nix/var/nix/profiles/default/bin/* /usr/lib/postgresql/bin/ \
+    && ln -s /nix/var/nix/profiles/default/bin/* /usr/bin/ \
+    && chown -R postgres:postgres /usr/bin
+
+# Create symbolic links for PostgreSQL shares
+RUN ln -s /nix/var/nix/profiles/default/share/postgresql/* /usr/lib/postgresql/share/postgresql/
+RUN ln -s /nix/var/nix/profiles/default/share/postgresql/* /usr/share/postgresql/
+RUN chown -R postgres:postgres /usr/lib/postgresql/share/postgresql/
+RUN chown -R postgres:postgres /usr/share/postgresql/
+# Create symbolic links for contrib directory
+RUN mkdir -p /usr/lib/postgresql/share/postgresql/contrib \
+    && find /nix/var/nix/profiles/default/share/postgresql/contrib/ -mindepth 1 -type d -exec sh -c 'for dir do ln -s "$dir" "/usr/lib/postgresql/share/postgresql/contrib/$(basename "$dir")"; done' sh {} + \
+    && chown -R postgres:postgres /usr/lib/postgresql/share/postgresql/contrib/
+
+RUN chown -R postgres:postgres /usr/lib/postgresql
+
+RUN ln -sf /usr/lib/postgresql/share/postgresql/timezonesets /usr/share/postgresql/timezonesets
+
+# Set non-interactive mode so tzdata does not prompt for input
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends tzdata
+
+RUN ln -fs /usr/share/zoneinfo/Etc/UTC /etc/localtime && \
+    dpkg-reconfigure --frontend noninteractive tzdata
+
+# Switch back to the default interactive mode
+ENV DEBIAN_FRONTEND=interactive
+
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends \
+    build-essential \
+    checkinstall \
+    cmake
+
+ENV PGDATA=/var/lib/postgresql/data
+
+####################
+# setup-wal-g.yml
+####################
+FROM base as walg
+ARG wal_g_release
+ARG TARGETARCH
+# ADD "https://github.com/wal-g/wal-g/releases/download/v${wal_g_release}/wal-g-pg-ubuntu-20.04-${TARGETARCH}.tar.gz" /tmp/wal-g.tar.gz
+RUN arch=$([ "$TARGETARCH" = "arm64" ] && echo "aarch64" || echo "$TARGETARCH") && \
+    apt-get update && apt-get install -y --no-install-recommends curl && \
+    curl -kL "https://github.com/wal-g/wal-g/releases/download/v${wal_g_release}/wal-g-pg-ubuntu-20.04-${arch}.tar.gz" -o /tmp/wal-g.tar.gz && \
+    tar -xvf /tmp/wal-g.tar.gz -C /tmp && \
+    rm -rf /tmp/wal-g.tar.gz && \
+    mv /tmp/wal-g-pg-ubuntu*20.04-${arch} /tmp/wal-g
+
+# ####################
+# # Download gosu for easy step-down from root
+# ####################
+FROM base as gosu
+ARG TARGETARCH
+# Install dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends supervisor \
+    gnupg \
+    ca-certificates \
+    && rm -rf /var/lib/apt/lists/*
+# Download binary
+ARG GOSU_VERSION=1.16
+ARG GOSU_GPG_KEY=B42F6819007F00F88E364FD4036A9C25BF357DD4
+ADD https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-$TARGETARCH \
+    /usr/local/bin/gosu
+ADD https://github.com/tianon/gosu/releases/download/$GOSU_VERSION/gosu-$TARGETARCH.asc \
+    /usr/local/bin/gosu.asc
+# Verify checksum
+RUN gpg --batch --keyserver hkps://keys.openpgp.org --recv-keys $GOSU_GPG_KEY && \
+    gpg --batch --verify /usr/local/bin/gosu.asc /usr/local/bin/gosu && \
+    gpgconf --kill all && \
+    chmod +x /usr/local/bin/gosu
+
+# ####################
+# # Build final image
+# ####################
+FROM gosu as production
+RUN id postgres || (echo "postgres user does not exist" && exit 1)
+# # Setup extensions
+COPY --from=walg /tmp/wal-g /usr/local/bin/
+
+# # Initialise configs
+COPY --chown=postgres:postgres postgres_15.8.1.044/ansible/files/postgresql_config/postgresql.conf.j2 /etc/postgresql/postgresql.conf
+COPY --chown=postgres:postgres postgres_15.8.1.044/ansible/files/postgresql_config/pg_hba.conf.j2 /etc/postgresql/pg_hba.conf
+COPY --chown=postgres:postgres postgres_15.8.1.044/ansible/files/postgresql_config/pg_ident.conf.j2 /etc/postgresql/pg_ident.conf
+COPY --chown=postgres:postgres postgres_15.8.1.044/ansible/files/postgresql_config/postgresql-stdout-log.conf /etc/postgresql/logging.conf
+COPY --chown=postgres:postgres postgres_15.8.1.044/ansible/files/postgresql_config/supautils.conf.j2 /etc/postgresql-custom/supautils.conf
+COPY --chown=postgres:postgres postgres_15.8.1.044/ansible/files/postgresql_extension_custom_scripts /etc/postgresql-custom/extension-custom-scripts
+COPY --chown=postgres:postgres postgres_15.8.1.044/ansible/files/pgsodium_getkey_urandom.sh.j2 /usr/lib/postgresql/bin/pgsodium_getkey.sh
+COPY --chown=postgres:postgres postgres_15.8.1.044/ansible/files/postgresql_config/custom_read_replica.conf.j2 /etc/postgresql-custom/read-replica.conf
+COPY --chown=postgres:postgres postgres_15.8.1.044/ansible/files/postgresql_config/custom_walg.conf.j2 /etc/postgresql-custom/wal-g.conf
+COPY --chown=postgres:postgres postgres_15.8.1.044/ansible/files/walg_helper_scripts/wal_fetch.sh /home/postgres/wal_fetch.sh
+COPY postgres_15.8.1.044/ansible/files/walg_helper_scripts/wal_change_ownership.sh /root/wal_change_ownership.sh
+
+RUN sed -i \
+    -e "s|#unix_socket_directories = '/tmp'|unix_socket_directories = '/var/run/postgresql'|g" \
+    -e "s|#session_preload_libraries = ''|session_preload_libraries = 'supautils'|g" \
+    -e "s|#include = '/etc/postgresql-custom/supautils.conf'|include = '/etc/postgresql-custom/supautils.conf'|g" \
+    -e "s|#include = '/etc/postgresql-custom/wal-g.conf'|include = '/etc/postgresql-custom/wal-g.conf'|g" /etc/postgresql/postgresql.conf && \
+    echo "cron.database_name = 'postgres'" >> /etc/postgresql/postgresql.conf && \
+    #echo "pljava.libjvm_location = '/usr/lib/jvm/java-11-openjdk-${TARGETARCH}/lib/server/libjvm.so'" >> /etc/postgresql/postgresql.conf && \
+    echo "pgsodium.getkey_script= '/usr/lib/postgresql/bin/pgsodium_getkey.sh'" >> /etc/postgresql/postgresql.conf && \
+    echo 'auto_explain.log_min_duration = 10s' >> /etc/postgresql/postgresql.conf && \
+    usermod -aG postgres wal-g && \
+    mkdir -p /etc/postgresql-custom && \
+    chown postgres:postgres /etc/postgresql-custom
+
+# # Include schema migrations
+COPY postgres_15.8.1.044/migrations/db /docker-entrypoint-initdb.d/
+COPY postgres_15.8.1.044/ansible/files/pgbouncer_config/pgbouncer_auth_schema.sql /docker-entrypoint-initdb.d/init-scripts/00-schema.sql
+COPY postgres_15.8.1.044/ansible/files/stat_extension.sql /docker-entrypoint-initdb.d/migrations/00-extension.sql
+
+# # Add upstream entrypoint script
+COPY --from=gosu /usr/local/bin/gosu /usr/local/bin/gosu
+ADD --chmod=0755 \
+    https://github.com/docker-library/postgres/raw/master/15/bullseye/docker-entrypoint.sh \
+    /usr/local/bin/
+
+RUN mkdir -p /var/run/postgresql && chown postgres:postgres /var/run/postgresql
+
+COPY ./supabase ./supabase
+RUN chmod +x /supabase/postgres/wrapper.sh /supabase/postgrest/wrapper.sh /supabase/gotrue/wrapper.sh /supabase/storage-api/wrapper.sh /supabase/kong/wrapper.sh
+
+#ENTRYPOINT ["docker-entrypoint.sh"]
["docker-entrypoint.sh"] +ENTRYPOINT ["supervisord"] + +HEALTHCHECK --interval=2s --timeout=2s --retries=10 CMD pg_isready -U postgres -h localhost +STOPSIGNAL SIGINT +#EXPOSE 5432 + +ENV POSTGRES_HOST=/var/run/postgresql +ENV POSTGRES_USER=supabase_admin +ENV POSTGRES_DB=postgres +RUN apt-get update && apt-get install -y --no-install-recommends \ + locales \ + && rm -rf /var/lib/apt/lists/* && \ + localedef -i en_US -c -f UTF-8 -A /usr/share/locale/locale.alias en_US.UTF-8 \ + && localedef -i C -c -f UTF-8 -A /usr/share/locale/locale.alias C.UTF-8 +RUN echo "C.UTF-8 UTF-8" > /etc/locale.gen && echo "en_US.UTF-8 UTF-8" >> /etc/locale.gen && locale-gen +ENV LANG en_US.UTF-8 +ENV LANGUAGE en_US:en +ENV LC_ALL en_US.UTF-8 +ENV LC_CTYPE=C.UTF-8 +ENV LC_COLLATE=C.UTF-8 +ENV LOCALE_ARCHIVE /usr/lib/locale/locale-archive + +CMD ["-c", "supabase/postgres/supervisord.conf"] +#CMD ["postgres", "-D", "/etc/postgresql"] + +#============================================ postgrest =========================================== +RUN apt-get update -y \ + && apt install -y --no-install-recommends libpq-dev zlib1g-dev jq gcc libnuma-dev \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +COPY postgrest_v12.2.8/postgrest /usr/bin/postgrest +RUN chmod +x /usr/bin/postgrest + +#=========================================== goture include============================================== +#RUN useradd -m -u 1000 supabase + +RUN apt-get update && apt-get install -y --no-install-recommends ca-certificates && rm -rf /var/lib/apt/lists/* +COPY --from=authbuild /go/src/github.com/supabase/auth/auth /usr/local/bin/auth +COPY --from=authbuild /go/src/github.com/supabase/auth/migrations /usr/local/etc/auth/migrations/ +RUN ln -s /usr/local/bin/auth /usr/local/bin/gotrue + +ENV GOTRUE_DB_MIGRATIONS_PATH=/usr/local/etc/auth/migrations + +#========================================== storage-api ==================================================== +ARG VERSION +ENV VERSION=$VERSION + +RUN curl -fsSL https://deb.nodesource.com/setup_18.x | bash - && \ + apt-get update && \ + apt-get install -y nodejs && \ + npm install -g npm@10.8.2 pnpm@10.9.0 + +COPY --from=s3final node_modules /supabase/storage-api/node_modules +COPY --from=s3final dist /supabase/storage-api/dist +COPY --from=s3final migrations /supabase/storage-api/migrations + +#========================================= chatai-ui ======================================================== +# # 拷贝依赖声明并安装仅生产依赖 +COPY chatdesk-ui/package.json chatdesk-ui/package-lock.json supabase/chatdesk/ +WORKDIR /supabase/chatdesk +RUN npm ci + +# 拷贝构建产物和依赖 +COPY --from=chataibuilder /app/.next ./.next +COPY --from=chataibuilder /app/public ./public +COPY --from=chataibuilder /app/next.config.js ./next.config.js +COPY chatdesk-ui/.env.local ./.env.local +COPY chatdesk-ui/supabase ./supabase + +WORKDIR / +ENV NODE_ENV=production +EXPOSE 3030 \ No newline at end of file diff --git a/auth_v2.169.0/.dockerignore b/auth_v2.169.0/.dockerignore new file mode 100644 index 0000000..8efdb51 --- /dev/null +++ b/auth_v2.169.0/.dockerignore @@ -0,0 +1,3 @@ +/hack/ +/vendor/ +/www/ diff --git a/auth_v2.169.0/.gitattributes b/auth_v2.169.0/.gitattributes new file mode 100644 index 0000000..bbfc097 --- /dev/null +++ b/auth_v2.169.0/.gitattributes @@ -0,0 +1,41 @@ +# Set the default behavior +* text=auto + +# Go files +*.mod text eol=lf +*.sum text eol=lf +*.go text eol=lf + +# Serialization +*.yml eol=lf +*.yaml eol=lf +*.toml eol=lf +*.json eol=lf + +# Scripts +*.sh eol=lf + +# DB files +*.sql eol=lf + +# 
Html +*.html eol=lf + +# Text and markdown files +*.txt text eol=lf +*.md text eol=lf + +# Environment files/examples +*.env text eol=lf + +# Docker files +.dockerignore text eol=lf +Dockerfile* text eol=lf + +# Makefile +Makefile text eol=lf + +# Git files +.gitignore text eol=lf +.gitattributes text eol=lf +.gitkeep text eol=lf \ No newline at end of file diff --git a/auth_v2.169.0/.gitignore b/auth_v2.169.0/.gitignore new file mode 100644 index 0000000..acab1be --- /dev/null +++ b/auth_v2.169.0/.gitignore @@ -0,0 +1,18 @@ +.env* +vendor/ +gotrue +gotrue-arm64 +gotrue.exe +auth +auth-arm64 +auth.exe + +coverage.out + +.DS_Store +.vscode +www/dist/ +www/.DS_Store +www/node_modules +npm-debug.log +.data diff --git a/auth_v2.169.0/.releaserc b/auth_v2.169.0/.releaserc new file mode 100644 index 0000000..32f0e45 --- /dev/null +++ b/auth_v2.169.0/.releaserc @@ -0,0 +1,10 @@ +{ + "branches": [ + "master" + ], + "plugins": [ + "@semantic-release/commit-analyzer", + "@semantic-release/release-notes-generator", + "@semantic-release/github" + ] +} diff --git a/auth_v2.169.0/CHANGELOG.md b/auth_v2.169.0/CHANGELOG.md new file mode 100644 index 0000000..551852f --- /dev/null +++ b/auth_v2.169.0/CHANGELOG.md @@ -0,0 +1,601 @@ +# Changelog + +## [2.169.0](https://github.com/supabase/auth/compare/v2.168.0...v2.169.0) (2025-01-27) + + +### Features + +* add an optional burstable rate limiter ([#1924](https://github.com/supabase/auth/issues/1924)) ([1f06f58](https://github.com/supabase/auth/commit/1f06f58e1434b91612c0d96c8c0435d26570f3e2)) +* cover 100% of crypto with tests ([#1892](https://github.com/supabase/auth/issues/1892)) ([174198e](https://github.com/supabase/auth/commit/174198e56f8e9b8470a717d0021c626130288d2e)) + + +### Bug Fixes + +* convert refreshed_at to UTC before updating ([#1916](https://github.com/supabase/auth/issues/1916)) ([a4c692f](https://github.com/supabase/auth/commit/a4c692f6cb1b8bf4c47ea012872af5ce93382fbf)) +* correct casing of API key authentication in openapi.yaml ([0cfd177](https://github.com/supabase/auth/commit/0cfd177b8fb1df8f62e84fbd3761ef9f90c384de)) +* improve invalid channel error message returned ([#1908](https://github.com/supabase/auth/issues/1908)) ([f72f0ee](https://github.com/supabase/auth/commit/f72f0eee328fa0aa041155f5f5dc305f0874d2bf)) +* improve saml assertion logging ([#1915](https://github.com/supabase/auth/issues/1915)) ([d6030cc](https://github.com/supabase/auth/commit/d6030ccd271a381e2a6ababa11a5beae4b79e5c3)) + +## [2.168.0](https://github.com/supabase/auth/compare/v2.167.0...v2.168.0) (2025-01-06) + + +### Features + +* set `email_verified` to true on all identities with the verified email ([#1902](https://github.com/supabase/auth/issues/1902)) ([307892f](https://github.com/supabase/auth/commit/307892f85b39150074fbb80b9c8f45ac3312aae2)) + +## [2.167.0](https://github.com/supabase/auth/compare/v2.166.0...v2.167.0) (2024-12-24) + + +### Features + +* fix argon2 parsing and comparison ([#1887](https://github.com/supabase/auth/issues/1887)) ([9dbe6ef](https://github.com/supabase/auth/commit/9dbe6ef931ae94e621d55a5f7aea4b7ee0449949)) + +## [2.166.0](https://github.com/supabase/auth/compare/v2.165.0...v2.166.0) (2024-12-23) + + +### Features + +* switch to googleapis/release-please-action, bump to 2.166.0 ([#1883](https://github.com/supabase/auth/issues/1883)) ([11a312f](https://github.com/supabase/auth/commit/11a312fcf77771b3732f2f439078225895df7a85)) + + +### Bug Fixes + +* check if session is nil ([#1873](https://github.com/supabase/auth/issues/1873)) 
([fd82601](https://github.com/supabase/auth/commit/fd82601917adcd9f8c38263953eb1ef098b26b7f)) +* email_verified field not being updated on signup confirmation ([#1868](https://github.com/supabase/auth/issues/1868)) ([483463e](https://github.com/supabase/auth/commit/483463e49eec7b2974cca05eadca6b933b2145b5)) +* handle user banned error code ([#1851](https://github.com/supabase/auth/issues/1851)) ([a6918f4](https://github.com/supabase/auth/commit/a6918f49baee42899b3ae1b7b6bc126d84629c99)) +* Revert "fix: revert fallback on btree indexes when hash is unavailable" ([#1859](https://github.com/supabase/auth/issues/1859)) ([9fe5b1e](https://github.com/supabase/auth/commit/9fe5b1eebfafb385d6b5d10196aeb2a1964ab296)) +* skip cleanup for non-2xx status ([#1877](https://github.com/supabase/auth/issues/1877)) ([f572ced](https://github.com/supabase/auth/commit/f572ced3699c7f920deccce1a3539299541ec94c)) + +## [2.165.1](https://github.com/supabase/auth/compare/v2.165.0...v2.165.1) (2024-12-06) + + +### Bug Fixes + +* allow setting the mailer service headers as strings ([#1861](https://github.com/supabase/auth/issues/1861)) ([7907b56](https://github.com/supabase/auth/commit/7907b566228f7e2d76049b44cfe0cc808c109100)) + +## [2.165.0](https://github.com/supabase/auth/compare/v2.164.0...v2.165.0) (2024-12-05) + + +### Features + +* add email validation function to lower bounce rates ([#1845](https://github.com/supabase/auth/issues/1845)) ([2c291f0](https://github.com/supabase/auth/commit/2c291f0356f3e91063b6b43bf2a21625b0ce0ebd)) +* use embedded migrations for `migrate` command ([#1843](https://github.com/supabase/auth/issues/1843)) ([e358da5](https://github.com/supabase/auth/commit/e358da5f0e267725a77308461d0a4126436fc537)) + + +### Bug Fixes + +* fallback on btree indexes when hash is unavailable ([#1856](https://github.com/supabase/auth/issues/1856)) ([b33bc31](https://github.com/supabase/auth/commit/b33bc31c07549dc9dc221100995d6f6b6754fd3a)) +* return the error code instead of status code ([#1855](https://github.com/supabase/auth/issues/1855)) ([834a380](https://github.com/supabase/auth/commit/834a380d803ae9ce59ce5ee233fa3a78a984fe68)) +* revert fallback on btree indexes when hash is unavailable ([#1858](https://github.com/supabase/auth/issues/1858)) ([1c7202f](https://github.com/supabase/auth/commit/1c7202ff835856562ee66b33be131eca769acf1d)) +* update ip mismatch error message ([#1849](https://github.com/supabase/auth/issues/1849)) ([49fbbf0](https://github.com/supabase/auth/commit/49fbbf03917a1085c58e9a1ff76c247ae6bb9ca7)) + +## [2.164.0](https://github.com/supabase/auth/compare/v2.163.2...v2.164.0) (2024-11-13) + + +### Features + +* return validation failed error if captcha request was not json ([#1815](https://github.com/supabase/auth/issues/1815)) ([26d2e36](https://github.com/supabase/auth/commit/26d2e36bba29eb8a6ddba556acfd0820f3bfde5d)) + + +### Bug Fixes + +* add error codes to refresh token flow ([#1824](https://github.com/supabase/auth/issues/1824)) ([4614dc5](https://github.com/supabase/auth/commit/4614dc54ab1dcb5390cfed05441e7888af017d92)) +* add test coverage for rate limits with 0 permitted events ([#1834](https://github.com/supabase/auth/issues/1834)) ([7c3cf26](https://github.com/supabase/auth/commit/7c3cf26cfe2a3e4de579d10509945186ad719855)) +* correct web authn aaguid column naming ([#1826](https://github.com/supabase/auth/issues/1826)) ([0a589d0](https://github.com/supabase/auth/commit/0a589d04e1cd9310cb260d329bc8beb050adf8da)) +* default to files:read scope for Figma provider 
([#1831](https://github.com/supabase/auth/issues/1831)) ([9ce2857](https://github.com/supabase/auth/commit/9ce28570bf3da9571198d44d693c7ad7038cde33)) +* improve error messaging for http hooks ([#1821](https://github.com/supabase/auth/issues/1821)) ([fa020d0](https://github.com/supabase/auth/commit/fa020d0fc292d5c381c57ecac6666d9ff657e4c4)) +* make drop_uniqueness_constraint_on_phone idempotent ([#1817](https://github.com/supabase/auth/issues/1817)) ([158e473](https://github.com/supabase/auth/commit/158e4732afa17620cdd89c85b7b57569feea5c21)) +* possible panic if refresh token has a null session_id ([#1822](https://github.com/supabase/auth/issues/1822)) ([a7129df](https://github.com/supabase/auth/commit/a7129df4e1d91a042b56ff1f041b9c6598825475)) +* rate limits of 0 take precedence over MAILER_AUTO_CONFIRM ([#1837](https://github.com/supabase/auth/issues/1837)) ([cb7894e](https://github.com/supabase/auth/commit/cb7894e1119d27d527dedcca22d8b3d433beddac)) + +## [2.163.2](https://github.com/supabase/auth/compare/v2.163.1...v2.163.2) (2024-10-22) + + +### Bug Fixes + +* ignore rate limits for autoconfirm ([#1810](https://github.com/supabase/auth/issues/1810)) ([9ce2340](https://github.com/supabase/auth/commit/9ce23409f960a8efa55075931138624cb681eca5)) + +## [2.163.1](https://github.com/supabase/auth/compare/v2.163.0...v2.163.1) (2024-10-22) + + +### Bug Fixes + +* external host validation ([#1808](https://github.com/supabase/auth/issues/1808)) ([4f6a461](https://github.com/supabase/auth/commit/4f6a4617074e61ba3b31836ccb112014904ce97c)), closes [#1228](https://github.com/supabase/auth/issues/1228) + +## [2.163.0](https://github.com/supabase/auth/compare/v2.162.2...v2.163.0) (2024-10-15) + + +### Features + +* add mail header support via `GOTRUE_SMTP_HEADERS` with `$messageType` ([#1804](https://github.com/supabase/auth/issues/1804)) ([99d6a13](https://github.com/supabase/auth/commit/99d6a134c44554a8ad06695e1dff54c942c8335d)) +* add MFA for WebAuthn ([#1775](https://github.com/supabase/auth/issues/1775)) ([8cc2f0e](https://github.com/supabase/auth/commit/8cc2f0e14d06d0feb56b25a0278fda9e213b6b5a)) +* configurable email and sms rate limiting ([#1800](https://github.com/supabase/auth/issues/1800)) ([5e94047](https://github.com/supabase/auth/commit/5e9404717e1c962ab729cde150ef5b40ea31a6e8)) +* mailer logging ([#1805](https://github.com/supabase/auth/issues/1805)) ([9354b83](https://github.com/supabase/auth/commit/9354b83a48a3edcb49197c997a1e96efc80c5383)) +* preserve rate limiters in memory across configuration reloads ([#1792](https://github.com/supabase/auth/issues/1792)) ([0a3968b](https://github.com/supabase/auth/commit/0a3968b02b9f044bfb7e5ebc71dca970d2bb7807)) + + +### Bug Fixes + +* add twilio verify support on mfa ([#1714](https://github.com/supabase/auth/issues/1714)) ([aeb5d8f](https://github.com/supabase/auth/commit/aeb5d8f8f18af60ce369cab5714979ac0c208308)) +* email header setting no longer misleading ([#1802](https://github.com/supabase/auth/issues/1802)) ([3af03be](https://github.com/supabase/auth/commit/3af03be6b65c40f3f4f62ce9ab989a20d75ae53a)) +* enforce authorized address checks on send email only ([#1806](https://github.com/supabase/auth/issues/1806)) ([c0c5b23](https://github.com/supabase/auth/commit/c0c5b23728c8fb633dae23aa4b29ed60e2691a2b)) +* fix `getExcludedColumns` slice allocation ([#1788](https://github.com/supabase/auth/issues/1788)) ([7f006b6](https://github.com/supabase/auth/commit/7f006b63c8d7e28e55a6d471881e9c118df80585)) +* Fix reqPath for bypass check for verify EP 
([#1789](https://github.com/supabase/auth/issues/1789)) ([646dc66](https://github.com/supabase/auth/commit/646dc66ea8d59a7f78bf5a5e55d9b5065a718c23)) +* inline mailme package for easy development ([#1803](https://github.com/supabase/auth/issues/1803)) ([fa6f729](https://github.com/supabase/auth/commit/fa6f729a027eff551db104550fa626088e00bc15)) + +## [2.162.2](https://github.com/supabase/auth/compare/v2.162.1...v2.162.2) (2024-10-05) + + +### Bug Fixes + +* refactor mfa validation into functions ([#1780](https://github.com/supabase/auth/issues/1780)) ([410b8ac](https://github.com/supabase/auth/commit/410b8acdd659fc4c929fe57a9e9dba4c76da305d)) +* upgrade ci Go version ([#1782](https://github.com/supabase/auth/issues/1782)) ([97a48f6](https://github.com/supabase/auth/commit/97a48f6daaa2edda5b568939cbb1007ccdf33cfc)) +* validateEmail should normalise emails ([#1790](https://github.com/supabase/auth/issues/1790)) ([2e9b144](https://github.com/supabase/auth/commit/2e9b144a0cbf2d26d3c4c2eafbff1899a36aeb3b)) + +## [2.162.1](https://github.com/supabase/auth/compare/v2.162.0...v2.162.1) (2024-10-03) + + +### Bug Fixes + +* bypass check for token & verify endpoints ([#1785](https://github.com/supabase/auth/issues/1785)) ([9ac2ea0](https://github.com/supabase/auth/commit/9ac2ea0180826cd2f65e679524aabfb10666e973)) + +## [2.162.0](https://github.com/supabase/auth/compare/v2.161.0...v2.162.0) (2024-09-27) + + +### Features + +* add support for migration of firebase scrypt passwords ([#1768](https://github.com/supabase/auth/issues/1768)) ([ba00f75](https://github.com/supabase/auth/commit/ba00f75c28d6708ddf8ee151ce18f2d6193689ef)) + + +### Bug Fixes + +* apply authorized email restriction to non-admin routes ([#1778](https://github.com/supabase/auth/issues/1778)) ([1af203f](https://github.com/supabase/auth/commit/1af203f92372e6db12454a0d319aad8ce3d149e7)) +* magiclink failing due to passwordStrength check ([#1769](https://github.com/supabase/auth/issues/1769)) ([7a5411f](https://github.com/supabase/auth/commit/7a5411f1d4247478f91027bc4969cbbe95b7774c)) + +## [2.161.0](https://github.com/supabase/auth/compare/v2.160.0...v2.161.0) (2024-09-24) + + +### Features + +* add `x-sb-error-code` header, show error code in logs ([#1765](https://github.com/supabase/auth/issues/1765)) ([ed91c59](https://github.com/supabase/auth/commit/ed91c59aa332738bd0ac4b994aeec2cdf193a068)) +* add webauthn configuration variables ([#1773](https://github.com/supabase/auth/issues/1773)) ([77d5897](https://github.com/supabase/auth/commit/77d58976ae624dbb7f8abee041dd4557aab81109)) +* config reloading ([#1771](https://github.com/supabase/auth/issues/1771)) ([6ee0091](https://github.com/supabase/auth/commit/6ee009163bfe451e2a0b923705e073928a12c004)) + + +### Bug Fixes + +* add additional information around errors for missing content type header ([#1576](https://github.com/supabase/auth/issues/1576)) ([c2b2f96](https://github.com/supabase/auth/commit/c2b2f96f07c97c15597cd972b1cd672238d87cdc)) +* add token to hook payload for non-secure email change ([#1763](https://github.com/supabase/auth/issues/1763)) ([7e472ad](https://github.com/supabase/auth/commit/7e472ad72042e86882dab3fddce9fafa66a8236c)) +* update aal requirements to update user ([#1766](https://github.com/supabase/auth/issues/1766)) ([25d9874](https://github.com/supabase/auth/commit/25d98743f6cc2cca2b490a087f468c8556ec5e44)) +* update mfa admin methods ([#1774](https://github.com/supabase/auth/issues/1774)) 
([567ea7e](https://github.com/supabase/auth/commit/567ea7ebd18eacc5e6daea8adc72e59e94459991)) +* user sanitization should clean up email change info too ([#1759](https://github.com/supabase/auth/issues/1759)) ([9d419b4](https://github.com/supabase/auth/commit/9d419b400f0637b10e5c235b8fd5bac0d69352bd)) + +## [2.160.0](https://github.com/supabase/auth/compare/v2.159.2...v2.160.0) (2024-09-02) + + +### Features + +* add authorized email address support ([#1757](https://github.com/supabase/auth/issues/1757)) ([f3a28d1](https://github.com/supabase/auth/commit/f3a28d182d193cf528cc72a985dfeaf7ecb67056)) +* add option to disable magic links ([#1756](https://github.com/supabase/auth/issues/1756)) ([2ad0737](https://github.com/supabase/auth/commit/2ad07373aa9239eba94abdabbb01c9abfa8c48de)) +* add support for saml encrypted assertions ([#1752](https://github.com/supabase/auth/issues/1752)) ([c5480ef](https://github.com/supabase/auth/commit/c5480ef83248ec2e7e3d3d87f92f43f17161ed25)) + + +### Bug Fixes + +* apply shared limiters before email / sms is sent ([#1748](https://github.com/supabase/auth/issues/1748)) ([bf276ab](https://github.com/supabase/auth/commit/bf276ab49753642793471815727559172fea4efc)) +* simplify WaitForCleanup ([#1747](https://github.com/supabase/auth/issues/1747)) ([0084625](https://github.com/supabase/auth/commit/0084625ad0790dd7c14b412d932425f4b84bb4c8)) + +## [2.159.2](https://github.com/supabase/auth/compare/v2.159.1...v2.159.2) (2024-08-28) + + +### Bug Fixes + +* allow anonymous user to update password ([#1739](https://github.com/supabase/auth/issues/1739)) ([2d51956](https://github.com/supabase/auth/commit/2d519569d7b8540886d0a64bf3e561ef5f91eb63)) +* hide hook name ([#1743](https://github.com/supabase/auth/issues/1743)) ([7e38f4c](https://github.com/supabase/auth/commit/7e38f4cf37768fe2adf92bbd0723d1d521b3d74c)) +* remove server side cookie token methods ([#1742](https://github.com/supabase/auth/issues/1742)) ([c6efec4](https://github.com/supabase/auth/commit/c6efec4cbc950e01e1fd06d45ed821bd27c2ad08)) + +## [2.159.1](https://github.com/supabase/auth/compare/v2.159.0...v2.159.1) (2024-08-23) + + +### Bug Fixes + +* return oauth identity when user is created ([#1736](https://github.com/supabase/auth/issues/1736)) ([60cfb60](https://github.com/supabase/auth/commit/60cfb6063afa574dfe4993df6b0e087d4df71309)) + +## [2.159.0](https://github.com/supabase/auth/compare/v2.158.1...v2.159.0) (2024-08-21) + + +### Features + +* Vercel marketplace OIDC ([#1731](https://github.com/supabase/auth/issues/1731)) ([a9ff361](https://github.com/supabase/auth/commit/a9ff3612196af4a228b53a8bfb9c11785bcfba8d)) + + +### Bug Fixes + +* add error codes to password login flow ([#1721](https://github.com/supabase/auth/issues/1721)) ([4351226](https://github.com/supabase/auth/commit/435122627a0784f1c5cb76d7e08caa1f6259423b)) +* change phone constraint to per user ([#1713](https://github.com/supabase/auth/issues/1713)) ([b9bc769](https://github.com/supabase/auth/commit/b9bc769b93b6e700925fcbc1ebf8bf9678034205)) +* custom SMS does not work with Twilio Verify ([#1733](https://github.com/supabase/auth/issues/1733)) ([dc2391d](https://github.com/supabase/auth/commit/dc2391d15f2c0725710aa388cd32a18797e6769c)) +* ignore errors if transaction has closed already ([#1726](https://github.com/supabase/auth/issues/1726)) ([53c11d1](https://github.com/supabase/auth/commit/53c11d173a79ae5c004871b1b5840c6f9425a080)) +* redirect invalid state errors to site url ([#1722](https://github.com/supabase/auth/issues/1722)) 
([b2b1123](https://github.com/supabase/auth/commit/b2b11239dc9f9bd3c85d76f6c23ee94beb3330bb)) +* remove TOTP field for phone enroll response ([#1717](https://github.com/supabase/auth/issues/1717)) ([4b04327](https://github.com/supabase/auth/commit/4b043275dd2d94600a8138d4ebf4638754ed926b)) +* use signing jwk to sign oauth state ([#1728](https://github.com/supabase/auth/issues/1728)) ([66fd0c8](https://github.com/supabase/auth/commit/66fd0c8434388bbff1e1bf02f40517aca0e9d339)) + +## [2.158.1](https://github.com/supabase/auth/compare/v2.158.0...v2.158.1) (2024-08-05) + + +### Bug Fixes + +* add last_challenged_at field to mfa factors ([#1705](https://github.com/supabase/auth/issues/1705)) ([29cbeb7](https://github.com/supabase/auth/commit/29cbeb799ff35ce528bfbd01b7103a24903d8061)) +* allow enabling sms hook without setting up sms provider ([#1704](https://github.com/supabase/auth/issues/1704)) ([575e88a](https://github.com/supabase/auth/commit/575e88ac345adaeb76ab6aae077307fdab9cda3c)) +* drop the MFA_ENABLED config ([#1701](https://github.com/supabase/auth/issues/1701)) ([078c3a8](https://github.com/supabase/auth/commit/078c3a8adcd51e57b68ab1b582549f5813cccd14)) +* enforce uniqueness on verified phone numbers ([#1693](https://github.com/supabase/auth/issues/1693)) ([70446cc](https://github.com/supabase/auth/commit/70446cc11d70b0493d742fe03f272330bb5b633e)) +* expose `X-Supabase-Api-Version` header in CORS ([#1612](https://github.com/supabase/auth/issues/1612)) ([6ccd814](https://github.com/supabase/auth/commit/6ccd814309dca70a9e3585543887194b05d725d3)) +* include factor_id in query ([#1702](https://github.com/supabase/auth/issues/1702)) ([ac14e82](https://github.com/supabase/auth/commit/ac14e82b33545466184da99e99b9d3fe5f3876d9)) +* move is owned by check to load factor ([#1703](https://github.com/supabase/auth/issues/1703)) ([701a779](https://github.com/supabase/auth/commit/701a779cf092e777dd4ad4954dc650164b09ab32)) +* refactor TOTP MFA into separate methods ([#1698](https://github.com/supabase/auth/issues/1698)) ([250d92f](https://github.com/supabase/auth/commit/250d92f9a18d38089d1bf262ef9088022a446965)) +* remove check for content-length ([#1700](https://github.com/supabase/auth/issues/1700)) ([81b332d](https://github.com/supabase/auth/commit/81b332d2f48622008469d2c5a9b130465a65f2a3)) +* remove FindFactorsByUser ([#1707](https://github.com/supabase/auth/issues/1707)) ([af8e2dd](https://github.com/supabase/auth/commit/af8e2dda15a1234a05e7d2d34d316eaa029e0912)) +* update openapi spec for MFA (Phone) ([#1689](https://github.com/supabase/auth/issues/1689)) ([a3da4b8](https://github.com/supabase/auth/commit/a3da4b89820c37f03ea128889616aca598d99f68)) + +## [2.158.0](https://github.com/supabase/auth/compare/v2.157.0...v2.158.0) (2024-07-31) + + +### Features + +* add hook log entry with `run_hook` action ([#1684](https://github.com/supabase/auth/issues/1684)) ([46491b8](https://github.com/supabase/auth/commit/46491b867a4f5896494417391392a373a453fa5f)) +* MFA (Phone) ([#1668](https://github.com/supabase/auth/issues/1668)) ([ae091aa](https://github.com/supabase/auth/commit/ae091aa942bdc5bc97481037508ec3bb4079d859)) + + +### Bug Fixes + +* maintain backward compatibility for asymmetric JWTs ([#1690](https://github.com/supabase/auth/issues/1690)) ([0ad1402](https://github.com/supabase/auth/commit/0ad1402444348e47e1e42be186b3f052d31be824)) +* MFA NewFactor to default to creating unverfied factors ([#1692](https://github.com/supabase/auth/issues/1692)) 
([3d448fa](https://github.com/supabase/auth/commit/3d448fa73cb77eb8511dbc47bfafecce4a4a2150)) +* minor spelling errors ([#1688](https://github.com/supabase/auth/issues/1688)) ([6aca52b](https://github.com/supabase/auth/commit/6aca52b56f8a6254de7709c767b9a5649f1da248)), closes [#1682](https://github.com/supabase/auth/issues/1682) +* treat `GOTRUE_MFA_ENABLED` as meaning TOTP enabled on enroll and verify ([#1694](https://github.com/supabase/auth/issues/1694)) ([8015251](https://github.com/supabase/auth/commit/8015251400bd52cbdad3ea28afb83b1cdfe816dd)) +* update mfa phone migration to be idempotent ([#1687](https://github.com/supabase/auth/issues/1687)) ([fdff1e7](https://github.com/supabase/auth/commit/fdff1e703bccf93217636266f1862bd0a9205edb)) + +## [2.157.0](https://github.com/supabase/auth/compare/v2.156.0...v2.157.0) (2024-07-26) + + +### Features + +* add asymmetric jwt support ([#1674](https://github.com/supabase/auth/issues/1674)) ([c7a2be3](https://github.com/supabase/auth/commit/c7a2be347b301b666e99adc3d3fed78c5e287c82)) + +## [2.156.0](https://github.com/supabase/auth/compare/v2.155.6...v2.156.0) (2024-07-25) + + +### Features + +* add is_anonymous claim to Auth hook jsonschema ([#1667](https://github.com/supabase/auth/issues/1667)) ([f9df65c](https://github.com/supabase/auth/commit/f9df65c91e226084abfa2e868ab6bab892d16d2f)) + + +### Bug Fixes + +* restrict autoconfirm email change to anonymous users ([#1679](https://github.com/supabase/auth/issues/1679)) ([b57e223](https://github.com/supabase/auth/commit/b57e2230102280ed873acf70be1aeb5a2f6f7a4f)) + +## [2.155.6](https://github.com/supabase/auth/compare/v2.155.5...v2.155.6) (2024-07-22) + + +### Bug Fixes + +* use deep equal ([#1672](https://github.com/supabase/auth/issues/1672)) ([8efd57d](https://github.com/supabase/auth/commit/8efd57dab40346762a04bac61b314ce05d6fa69c)) + +## [2.155.5](https://github.com/supabase/auth/compare/v2.155.4...v2.155.5) (2024-07-19) + + +### Bug Fixes + +* check password max length in checkPasswordStrength ([#1659](https://github.com/supabase/auth/issues/1659)) ([1858c93](https://github.com/supabase/auth/commit/1858c93bba6f5bc41e4c65489f12c1a0786a1f2b)) +* don't update attribute mapping if nil ([#1665](https://github.com/supabase/auth/issues/1665)) ([7e67f3e](https://github.com/supabase/auth/commit/7e67f3edbf81766df297a66f52a8e472583438c6)) +* refactor mfa models and add observability to loadFactor ([#1669](https://github.com/supabase/auth/issues/1669)) ([822fb93](https://github.com/supabase/auth/commit/822fb93faab325ba3d4bb628dff43381d68d0b5d)) + +## [2.155.4](https://github.com/supabase/auth/compare/v2.155.3...v2.155.4) (2024-07-17) + + +### Bug Fixes + +* treat empty string as nil in `encrypted_password` ([#1663](https://github.com/supabase/auth/issues/1663)) ([f99286e](https://github.com/supabase/auth/commit/f99286eaed505daf3db6f381265ef6024e7e36d2)) + +## [2.155.3](https://github.com/supabase/auth/compare/v2.155.2...v2.155.3) (2024-07-12) + + +### Bug Fixes + +* serialize jwt as string ([#1657](https://github.com/supabase/auth/issues/1657)) ([98d8324](https://github.com/supabase/auth/commit/98d83245e40d606438eb0afdbf474276179fd91d)) + +## [2.155.2](https://github.com/supabase/auth/compare/v2.155.1...v2.155.2) (2024-07-12) + + +### Bug Fixes + +* improve session error logging ([#1655](https://github.com/supabase/auth/issues/1655)) ([5a6793e](https://github.com/supabase/auth/commit/5a6793ee8fce7a089750fe10b3b63bb0a19d6d21)) +* omit empty string from name & use case-insensitive equality for comparing 
SAML attributes ([#1654](https://github.com/supabase/auth/issues/1654)) ([bf5381a](https://github.com/supabase/auth/commit/bf5381a6b1c686955dc4e39fe5fb806ffd309563)) +* set rate limit log level to warn ([#1652](https://github.com/supabase/auth/issues/1652)) ([10ca9c8](https://github.com/supabase/auth/commit/10ca9c806e4b67a371897f1b3f93c515764c4240)) + +## [2.155.1](https://github.com/supabase/auth/compare/v2.155.0...v2.155.1) (2024-07-04) + + +### Bug Fixes + +* apply mailer autoconfirm config to update user email ([#1646](https://github.com/supabase/auth/issues/1646)) ([a518505](https://github.com/supabase/auth/commit/a5185058e72509b0781e0eb59910ecdbb8676fee)) +* check for empty aud string ([#1649](https://github.com/supabase/auth/issues/1649)) ([42c1d45](https://github.com/supabase/auth/commit/42c1d4526b98203664d4a22c23014ecd0b4951f9)) +* return proper error if sms rate limit is exceeded ([#1647](https://github.com/supabase/auth/issues/1647)) ([3c8d765](https://github.com/supabase/auth/commit/3c8d7656431ac4b2e80726b7c37adb8f0c778495)) + +## [2.155.0](https://github.com/supabase/auth/compare/v2.154.2...v2.155.0) (2024-07-03) + + +### Features + +* add `password_hash` and `id` fields to admin create user ([#1641](https://github.com/supabase/auth/issues/1641)) ([20d59f1](https://github.com/supabase/auth/commit/20d59f10b601577683d05bcd7d2128ff4bc462a0)) + + +### Bug Fixes + +* improve mfa verify logs ([#1635](https://github.com/supabase/auth/issues/1635)) ([d8b47f9](https://github.com/supabase/auth/commit/d8b47f9d3f0dc8f97ad1de49e45f452ebc726481)) +* invited users should have a temporary password generated ([#1644](https://github.com/supabase/auth/issues/1644)) ([3f70d9d](https://github.com/supabase/auth/commit/3f70d9d8974d0e9c437c51e1312ad17ce9056ec9)) +* upgrade golang-jwt to v5 ([#1639](https://github.com/supabase/auth/issues/1639)) ([2cb97f0](https://github.com/supabase/auth/commit/2cb97f080fa4695766985cc4792d09476534be68)) +* use pointer for `user.EncryptedPassword` ([#1637](https://github.com/supabase/auth/issues/1637)) ([bbecbd6](https://github.com/supabase/auth/commit/bbecbd61a46b0c528b1191f48d51f166c06f4b16)) + +## [2.154.2](https://github.com/supabase/auth/compare/v2.154.1...v2.154.2) (2024-06-24) + + +### Bug Fixes + +* publish to ghcr.io/supabase/auth ([#1626](https://github.com/supabase/auth/issues/1626)) ([930aa3e](https://github.com/supabase/auth/commit/930aa3edb633823d4510c2aff675672df06f1211)), closes [#1625](https://github.com/supabase/auth/issues/1625) +* revert define search path in auth functions ([#1634](https://github.com/supabase/auth/issues/1634)) ([155e87e](https://github.com/supabase/auth/commit/155e87ef8129366d665968f64d1fc66676d07e16)) +* update MaxFrequency error message to reflect number of seconds ([#1540](https://github.com/supabase/auth/issues/1540)) ([e81c25d](https://github.com/supabase/auth/commit/e81c25d19551fdebfc5197d96bc220ddb0f8227b)) + +## [2.154.1](https://github.com/supabase/auth/compare/v2.154.0...v2.154.1) (2024-06-17) + + +### Bug Fixes + +* add ip based limiter ([#1622](https://github.com/supabase/auth/issues/1622)) ([06464c0](https://github.com/supabase/auth/commit/06464c013571253d1f18f7ae5e840826c4bd84a7)) +* admin user update should update is_anonymous field ([#1623](https://github.com/supabase/auth/issues/1623)) ([f5c6fcd](https://github.com/supabase/auth/commit/f5c6fcd9c3fee0f793f96880a8caebc5b5cb0916)) + +## [2.154.0](https://github.com/supabase/auth/compare/v2.153.0...v2.154.0) (2024-06-12) + + +### Features + +* add max length check 
for email ([#1508](https://github.com/supabase/auth/issues/1508)) ([f9c13c0](https://github.com/supabase/auth/commit/f9c13c0ad5c556bede49d3e0f6e5f58ca26161c3)) +* add support for Slack OAuth V2 ([#1591](https://github.com/supabase/auth/issues/1591)) ([bb99251](https://github.com/supabase/auth/commit/bb992519cdf7578dc02cd7de55e2e6aa09b4c0f3)) +* encrypt sensitive columns ([#1593](https://github.com/supabase/auth/issues/1593)) ([e4a4758](https://github.com/supabase/auth/commit/e4a475820b2dc1f985bd37df15a8ab9e781626f5)) +* upgrade otel to v1.26 ([#1585](https://github.com/supabase/auth/issues/1585)) ([cdd13ad](https://github.com/supabase/auth/commit/cdd13adec02eb0c9401bc55a2915c1005d50dea1)) +* use largest avatar from spotify instead ([#1210](https://github.com/supabase/auth/issues/1210)) ([4f9994b](https://github.com/supabase/auth/commit/4f9994bf792c3887f2f45910b11a9c19ee3a896b)), closes [#1209](https://github.com/supabase/auth/issues/1209) + + +### Bug Fixes + +* define search path in auth functions ([#1616](https://github.com/supabase/auth/issues/1616)) ([357bda2](https://github.com/supabase/auth/commit/357bda23cb2abd12748df80a9d27288aa548534d)) +* enable rls & update grants for auth tables ([#1617](https://github.com/supabase/auth/issues/1617)) ([28967aa](https://github.com/supabase/auth/commit/28967aa4b5db2363cc581c9da0d64e974eb7b64c)) + +## [2.153.0](https://github.com/supabase/auth/compare/v2.152.0...v2.153.0) (2024-06-04) + + +### Features + +* add SAML specific external URL config ([#1599](https://github.com/supabase/auth/issues/1599)) ([b352719](https://github.com/supabase/auth/commit/b3527190560381fafe9ba2fae4adc3b73703024a)) +* add support for verifying argon2i and argon2id passwords ([#1597](https://github.com/supabase/auth/issues/1597)) ([55409f7](https://github.com/supabase/auth/commit/55409f797bea55068a3fafdddd6cfdb78feba1b4)) +* make the email client explicity set the format to be HTML ([#1149](https://github.com/supabase/auth/issues/1149)) ([53e223a](https://github.com/supabase/auth/commit/53e223abdf29f4abcad13f99baf00daedcb00c3f)) + + +### Bug Fixes + +* call write header in write if not written ([#1598](https://github.com/supabase/auth/issues/1598)) ([0ef7eb3](https://github.com/supabase/auth/commit/0ef7eb30619d4c365e06a94a79b9cb0333d792da)) +* deadlock issue with timeout middleware write ([#1595](https://github.com/supabase/auth/issues/1595)) ([6c9fbd4](https://github.com/supabase/auth/commit/6c9fbd4bd5623c729906fca7857ab508166a3056)) +* improve token OIDC logging ([#1606](https://github.com/supabase/auth/issues/1606)) ([5262683](https://github.com/supabase/auth/commit/526268311844467664e89c8329e5aaee817dbbaf)) +* update contributing to use v1.22 ([#1609](https://github.com/supabase/auth/issues/1609)) ([5894d9e](https://github.com/supabase/auth/commit/5894d9e41e7681512a9904ad47082a705e948c98)) + +## [2.152.0](https://github.com/supabase/auth/compare/v2.151.0...v2.152.0) (2024-05-22) + + +### Features + +* new timeout writer implementation ([#1584](https://github.com/supabase/auth/issues/1584)) ([72614a1](https://github.com/supabase/auth/commit/72614a1fce27888f294772b512f8e31c55a36d87)) +* remove legacy lookup in users for one_time_tokens (phase II) ([#1569](https://github.com/supabase/auth/issues/1569)) ([39ca026](https://github.com/supabase/auth/commit/39ca026035f6c61d206d31772c661b326c2a424c)) +* update chi version ([#1581](https://github.com/supabase/auth/issues/1581)) ([c64ae3d](https://github.com/supabase/auth/commit/c64ae3dd775e8fb3022239252c31b4ee73893237)) +* 
update openapi spec with identity and is_anonymous fields ([#1573](https://github.com/supabase/auth/issues/1573)) ([86a79df](https://github.com/supabase/auth/commit/86a79df9ecfcf09fda0b8e07afbc41154fbb7d9d)) + + +### Bug Fixes + +* improve logging structure ([#1583](https://github.com/supabase/auth/issues/1583)) ([c22fc15](https://github.com/supabase/auth/commit/c22fc15d2a8383e95a2364f383dfa7dce5f5df88)) +* sms verify should update is_anonymous field ([#1580](https://github.com/supabase/auth/issues/1580)) ([e5f98cb](https://github.com/supabase/auth/commit/e5f98cb9e24ecebb0b7dc88c495fd456cc73fcba)) +* use api_external_url domain as localname ([#1575](https://github.com/supabase/auth/issues/1575)) ([ed2b490](https://github.com/supabase/auth/commit/ed2b4907244281e4c54aaef74b1f4c8a8e3d97c9)) + +## [2.151.0](https://github.com/supabase/auth/compare/v2.150.1...v2.151.0) (2024-05-06) + + +### Features + +* refactor one-time tokens for performance ([#1558](https://github.com/supabase/auth/issues/1558)) ([d1cf8d9](https://github.com/supabase/auth/commit/d1cf8d9096e9183d7772b73031de8ecbd66e912b)) + + +### Bug Fixes + +* do call send sms hook when SMS autoconfirm is enabled ([#1562](https://github.com/supabase/auth/issues/1562)) ([bfe4d98](https://github.com/supabase/auth/commit/bfe4d988f3768b0407526bcc7979fb21d8cbebb3)) +* format test otps ([#1567](https://github.com/supabase/auth/issues/1567)) ([434a59a](https://github.com/supabase/auth/commit/434a59ae387c35fd6629ec7c674d439537e344e5)) +* log final writer error instead of handling ([#1564](https://github.com/supabase/auth/issues/1564)) ([170bd66](https://github.com/supabase/auth/commit/170bd6615405afc852c7107f7358dfc837bad737)) + +## [2.150.1](https://github.com/supabase/auth/compare/v2.150.0...v2.150.1) (2024-04-28) + + +### Bug Fixes + +* add db conn max idle time setting ([#1555](https://github.com/supabase/auth/issues/1555)) ([2caa7b4](https://github.com/supabase/auth/commit/2caa7b4d75d2ff54af20f3e7a30a8eeec8cbcda9)) + +## [2.150.0](https://github.com/supabase/auth/compare/v2.149.0...v2.150.0) (2024-04-25) + + +### Features + +* add support for Azure CIAM login ([#1541](https://github.com/supabase/auth/issues/1541)) ([1cb4f96](https://github.com/supabase/auth/commit/1cb4f96bdc7ef3ef995781b4cf3c4364663a2bf3)) +* add timeout middleware ([#1529](https://github.com/supabase/auth/issues/1529)) ([f96ff31](https://github.com/supabase/auth/commit/f96ff31040b28e3a7373b4fd41b7334eda1b413e)) +* allow for postgres and http functions on each extensibility point ([#1528](https://github.com/supabase/auth/issues/1528)) ([348a1da](https://github.com/supabase/auth/commit/348a1daee24f6e44b14c018830b748e46d34b4c2)) +* merge provider metadata on link account ([#1552](https://github.com/supabase/auth/issues/1552)) ([bd8b5c4](https://github.com/supabase/auth/commit/bd8b5c41dd544575e1a52ccf1ef3f0fdee67458c)) +* send over user in SendSMS Hook instead of UserID ([#1551](https://github.com/supabase/auth/issues/1551)) ([d4d743c](https://github.com/supabase/auth/commit/d4d743c2ae9490e1b3249387e3b0d60df6913c68)) + + +### Bug Fixes + +* return error if session id does not exist ([#1538](https://github.com/supabase/auth/issues/1538)) ([91e9eca](https://github.com/supabase/auth/commit/91e9ecabe33a1c022f8e82a6050c22a7ca42de48)) + +## [2.149.0](https://github.com/supabase/auth/compare/v2.148.0...v2.149.0) (2024-04-15) + + +### Features + +* refactor generate accesss token to take in request ([#1531](https://github.com/supabase/auth/issues/1531)) 
([e4f2b59](https://github.com/supabase/auth/commit/e4f2b59e8e1f8158b6461a384349f1a32cc1bf9a)) + + +### Bug Fixes + +* linkedin_oidc provider error ([#1534](https://github.com/supabase/auth/issues/1534)) ([4f5e8e5](https://github.com/supabase/auth/commit/4f5e8e5120531e5a103fbdda91b51cabcb4e1a8c)) +* revert patch for linkedin_oidc provider error ([#1535](https://github.com/supabase/auth/issues/1535)) ([58ef4af](https://github.com/supabase/auth/commit/58ef4af0b4224b78cd9e59428788d16a8d31e562)) +* update linkedin issuer url ([#1536](https://github.com/supabase/auth/issues/1536)) ([10d6d8b](https://github.com/supabase/auth/commit/10d6d8b1eafa504da2b2a351d1f64a3a832ab1b9)) + +## [2.148.0](https://github.com/supabase/auth/compare/v2.147.1...v2.148.0) (2024-04-10) + + +### Features + +* add array attribute mapping for SAML ([#1526](https://github.com/supabase/auth/issues/1526)) ([7326285](https://github.com/supabase/auth/commit/7326285c8af5c42e5c0c2d729ab224cf33ac3a1f)) + +## [2.147.1](https://github.com/supabase/auth/compare/v2.147.0...v2.147.1) (2024-04-09) + + +### Bug Fixes + +* add validation and proper decoding on send email hook ([#1520](https://github.com/supabase/auth/issues/1520)) ([e19e762](https://github.com/supabase/auth/commit/e19e762e3e29729a1d1164c65461427822cc87f1)) +* remove deprecated LogoutAllRefreshTokens ([#1519](https://github.com/supabase/auth/issues/1519)) ([35533ea](https://github.com/supabase/auth/commit/35533ea100669559e1209ecc7b091db3657234d9)) + +## [2.147.0](https://github.com/supabase/auth/compare/v2.146.0...v2.147.0) (2024-04-05) + + +### Features + +* add send email Hook ([#1512](https://github.com/supabase/auth/issues/1512)) ([cf42e02](https://github.com/supabase/auth/commit/cf42e02ec63779f52b1652a7413f64994964c82d)) + +## [2.146.0](https://github.com/supabase/auth/compare/v2.145.0...v2.146.0) (2024-04-03) + + +### Features + +* add custom sms hook ([#1474](https://github.com/supabase/auth/issues/1474)) ([0f6b29a](https://github.com/supabase/auth/commit/0f6b29a46f1dcbf92aa1f7cb702f42e7640f5f93)) +* forbid generating an access token without a session ([#1504](https://github.com/supabase/auth/issues/1504)) ([795e93d](https://github.com/supabase/auth/commit/795e93d0afbe94bcd78489a3319a970b7bf8e8bc)) + + +### Bug Fixes + +* add cleanup statement for anonymous users ([#1497](https://github.com/supabase/auth/issues/1497)) ([cf2372a](https://github.com/supabase/auth/commit/cf2372a177796b829b72454e7491ce768bf5a42f)) +* generate signup link should not error ([#1514](https://github.com/supabase/auth/issues/1514)) ([4fc3881](https://github.com/supabase/auth/commit/4fc388186ac7e7a9a32ca9b963a83d6ac2eb7603)) +* move all EmailActionTypes to mailer package ([#1510](https://github.com/supabase/auth/issues/1510)) ([765db08](https://github.com/supabase/auth/commit/765db08582669a1b7f054217fa8f0ed45804c0b5)) +* refactor mfa and aal update methods ([#1503](https://github.com/supabase/auth/issues/1503)) ([31a5854](https://github.com/supabase/auth/commit/31a585429bf248aa919d94c82c7c9e0c1c695461)) +* rename from CustomSMSProvider to SendSMS ([#1513](https://github.com/supabase/auth/issues/1513)) ([c0bc37b](https://github.com/supabase/auth/commit/c0bc37b44effaebb62ba85102f072db07fe57e48)) + +## [2.145.0](https://github.com/supabase/gotrue/compare/v2.144.0...v2.145.0) (2024-03-26) + + +### Features + +* add error codes ([#1377](https://github.com/supabase/gotrue/issues/1377)) ([e4beea1](https://github.com/supabase/gotrue/commit/e4beea1cdb80544b0581f1882696a698fdf64938)) +* add kakao OIDC 
([#1381](https://github.com/supabase/gotrue/issues/1381)) ([b5566e7](https://github.com/supabase/gotrue/commit/b5566e7ac001cc9f2bac128de0fcb908caf3a5ed)) +* clean up expired factors ([#1371](https://github.com/supabase/gotrue/issues/1371)) ([5c94207](https://github.com/supabase/gotrue/commit/5c9420743a9aef0675f823c30aa4525b4933836e)) +* configurable NameID format for SAML provider ([#1481](https://github.com/supabase/gotrue/issues/1481)) ([ef405d8](https://github.com/supabase/gotrue/commit/ef405d89e69e008640f275bc37f8ec02ad32da40)) +* HTTP Hook - Add custom envconfig decoding for HTTP Hook Secrets ([#1467](https://github.com/supabase/gotrue/issues/1467)) ([5b24c4e](https://github.com/supabase/gotrue/commit/5b24c4eb05b2b52c4177d5f41cba30cb68495c8c)) +* refactor PKCE FlowState to reduce duplicate code ([#1446](https://github.com/supabase/gotrue/issues/1446)) ([b8d0337](https://github.com/supabase/gotrue/commit/b8d0337922c6712380f6dc74f7eac9fb71b1ae48)) + + +### Bug Fixes + +* add http support for https hooks on localhost ([#1484](https://github.com/supabase/gotrue/issues/1484)) ([5c04104](https://github.com/supabase/gotrue/commit/5c04104bf77a9c2db46d009764ec3ec3e484fc09)) +* cleanup panics due to bad inactivity timeout code ([#1471](https://github.com/supabase/gotrue/issues/1471)) ([548edf8](https://github.com/supabase/gotrue/commit/548edf898161c9ba9a136fc99ec2d52a8ba1f856)) +* **docs:** remove bracket on file name for broken link ([#1493](https://github.com/supabase/gotrue/issues/1493)) ([96f7a68](https://github.com/supabase/gotrue/commit/96f7a68a5479825e31106c2f55f82d5b2c007c0f)) +* impose expiry on auth code instead of magic link ([#1440](https://github.com/supabase/gotrue/issues/1440)) ([35aeaf1](https://github.com/supabase/gotrue/commit/35aeaf1b60dd27a22662a6d1955d60cc907b55dd)) +* invalidate email, phone OTPs on password change ([#1489](https://github.com/supabase/gotrue/issues/1489)) ([960a4f9](https://github.com/supabase/gotrue/commit/960a4f94f5500e33a0ec2f6afe0380bbc9562500)) +* move creation of flow state into function ([#1470](https://github.com/supabase/gotrue/issues/1470)) ([4392a08](https://github.com/supabase/gotrue/commit/4392a08d68d18828005d11382730117a7b143635)) +* prevent user email side-channel leak on verify ([#1472](https://github.com/supabase/gotrue/issues/1472)) ([311cde8](https://github.com/supabase/gotrue/commit/311cde8d1e82f823ae26a341e068034d60273864)) +* refactor email sending functions ([#1495](https://github.com/supabase/gotrue/issues/1495)) ([285c290](https://github.com/supabase/gotrue/commit/285c290adf231fea7ca1dff954491dc427cf18e2)) +* refactor factor_test to centralize setup ([#1473](https://github.com/supabase/gotrue/issues/1473)) ([c86007e](https://github.com/supabase/gotrue/commit/c86007e59684334b5e8c2285c36094b6eec89442)) +* refactor mfa challenge and tests ([#1469](https://github.com/supabase/gotrue/issues/1469)) ([6c76f21](https://github.com/supabase/gotrue/commit/6c76f21cee5dbef0562c37df6a546939affb2f8d)) +* Resend SMS when duplicate SMS sign ups are made ([#1490](https://github.com/supabase/gotrue/issues/1490)) ([73240a0](https://github.com/supabase/gotrue/commit/73240a0b096977703e3c7d24a224b5641ce47c81)) +* unlink identity bugs ([#1475](https://github.com/supabase/gotrue/issues/1475)) ([73e8d87](https://github.com/supabase/gotrue/commit/73e8d8742de3575b3165a707b5d2f486b2598d9d)) + +## [2.144.0](https://github.com/supabase/gotrue/compare/v2.143.0...v2.144.0) (2024-03-04) + + +### Features + +* add configuration for custom sms sender hook 
([#1428](https://github.com/supabase/gotrue/issues/1428)) ([1ea56b6](https://github.com/supabase/gotrue/commit/1ea56b62d47edb0766d9e445406ecb43d387d920)) +* anonymous sign-ins ([#1460](https://github.com/supabase/gotrue/issues/1460)) ([130df16](https://github.com/supabase/gotrue/commit/130df165270c69c8e28aaa1b9421342f997c1ff3)) +* clean up test setup in MFA tests ([#1452](https://github.com/supabase/gotrue/issues/1452)) ([7185af8](https://github.com/supabase/gotrue/commit/7185af8de4a269cdde2629054d222333d3522ebe)) +* pass transaction to `invokeHook`, fixing pool exhaustion ([#1465](https://github.com/supabase/gotrue/issues/1465)) ([b536d36](https://github.com/supabase/gotrue/commit/b536d368f35adb31f937169e3f093d28352fa7be)) +* refactor resource owner password grant ([#1443](https://github.com/supabase/gotrue/issues/1443)) ([e63ad6f](https://github.com/supabase/gotrue/commit/e63ad6ff0f67d9a83456918a972ecb5109125628)) +* use dummy instance id to improve performance on refresh token queries ([#1454](https://github.com/supabase/gotrue/issues/1454)) ([656474e](https://github.com/supabase/gotrue/commit/656474e1b9ff3d5129190943e8c48e456625afe5)) + + +### Bug Fixes + +* expose `provider` under `amr` in access token ([#1456](https://github.com/supabase/gotrue/issues/1456)) ([e9f38e7](https://github.com/supabase/gotrue/commit/e9f38e76d8a7b93c5c2bb0de918a9b156155f018)) +* improve MFA QR Code resilience so as to support providers like 1Password ([#1455](https://github.com/supabase/gotrue/issues/1455)) ([6522780](https://github.com/supabase/gotrue/commit/652278046c9dd92f5cecd778735b058ef3fb41c7)) +* refactor request params to use generics ([#1464](https://github.com/supabase/gotrue/issues/1464)) ([e1cdf5c](https://github.com/supabase/gotrue/commit/e1cdf5c4b5c1bf467094f4bdcaa2e42a5cc51c20)) +* revert refactor resource owner password grant ([#1466](https://github.com/supabase/gotrue/issues/1466)) ([fa21244](https://github.com/supabase/gotrue/commit/fa21244fa929709470c2e1fc4092a9ce947399e7)) +* update file name so migration to Drop IP Address is applied ([#1447](https://github.com/supabase/gotrue/issues/1447)) ([f29e89d](https://github.com/supabase/gotrue/commit/f29e89d7d2c48ee8fd5bf8279a7fa3db0ad4d842)) + +## [2.143.0](https://github.com/supabase/gotrue/compare/v2.142.0...v2.143.0) (2024-02-19) + + +### Features + +* calculate aal without transaction ([#1437](https://github.com/supabase/gotrue/issues/1437)) ([8dae661](https://github.com/supabase/gotrue/commit/8dae6614f1a2b58819f94894cef01e9f99117769)) + + +### Bug Fixes + +* deprecate hooks ([#1421](https://github.com/supabase/gotrue/issues/1421)) ([effef1b](https://github.com/supabase/gotrue/commit/effef1b6ecc448b7927eff23df8d5b509cf16b5c)) +* error should be an IsNotFoundError ([#1432](https://github.com/supabase/gotrue/issues/1432)) ([7f40047](https://github.com/supabase/gotrue/commit/7f40047aec3577d876602444b1d88078b2237d66)) +* populate password verification attempt hook ([#1436](https://github.com/supabase/gotrue/issues/1436)) ([f974bdb](https://github.com/supabase/gotrue/commit/f974bdb58340395955ca27bdd26d57062433ece9)) +* restrict mfa enrollment to aal2 if verified factors are present ([#1439](https://github.com/supabase/gotrue/issues/1439)) ([7e10d45](https://github.com/supabase/gotrue/commit/7e10d45e54010d38677f4c3f2f224127688eb9a2)) +* update phone if autoconfirm is enabled ([#1431](https://github.com/supabase/gotrue/issues/1431)) ([95db770](https://github.com/supabase/gotrue/commit/95db770c5d2ecca4a1e960a8cb28ded37cccc100)) +* use email 
change email in identity ([#1429](https://github.com/supabase/gotrue/issues/1429)) ([4d3b9b8](https://github.com/supabase/gotrue/commit/4d3b9b8841b1a5fa8f3244825153cc81a73ba300)) + +## [2.142.0](https://github.com/supabase/gotrue/compare/v2.141.0...v2.142.0) (2024-02-14) + + +### Features + +* alter tag to use raw ([#1427](https://github.com/supabase/gotrue/issues/1427)) ([53cfe5d](https://github.com/supabase/gotrue/commit/53cfe5de57d4b5ab6e8e2915493856ecd96f4ede)) +* update README.md to trigger release ([#1425](https://github.com/supabase/gotrue/issues/1425)) ([91e0e24](https://github.com/supabase/gotrue/commit/91e0e245f5957ebce13370f79fd4a6be8108ed80)) + +## [2.141.0](https://github.com/supabase/gotrue/compare/v2.140.0...v2.141.0) (2024-02-13) + + +### Features + +* drop sha hash tag ([#1422](https://github.com/supabase/gotrue/issues/1422)) ([76853ce](https://github.com/supabase/gotrue/commit/76853ce6d45064de5608acc8100c67a8337ba791)) +* prefix release with v ([#1424](https://github.com/supabase/gotrue/issues/1424)) ([9d398cd](https://github.com/supabase/gotrue/commit/9d398cd75fca01fb848aa88b4f545552e8b5751a)) + +## [2.140.0](https://github.com/supabase/gotrue/compare/v2.139.2...v2.140.0) (2024-02-13) + + +### Features + +* deprecate existing webhook implementation ([#1417](https://github.com/supabase/gotrue/issues/1417)) ([5301e48](https://github.com/supabase/gotrue/commit/5301e481b0c7278c18b4578a5b1aa8d2256c2f5d)) +* update publish.yml checkout repository so there is access to Dockerfile ([#1419](https://github.com/supabase/gotrue/issues/1419)) ([7cce351](https://github.com/supabase/gotrue/commit/7cce3518e8c9f1f3f93e4f6a0658ee08771c4f1c)) + +## [2.139.2](https://github.com/supabase/gotrue/compare/v2.139.1...v2.139.2) (2024-02-08) + + +### Bug Fixes + +* improve perf in account linking ([#1394](https://github.com/supabase/gotrue/issues/1394)) ([8eedb95](https://github.com/supabase/gotrue/commit/8eedb95dbaa310aac464645ec91d6a374813ab89)) +* OIDC provider validation log message ([#1380](https://github.com/supabase/gotrue/issues/1380)) ([27e6b1f](https://github.com/supabase/gotrue/commit/27e6b1f9a4394c5c4f8dff9a8b5529db1fc67af9)) +* only create or update the email / phone identity after it's been verified ([#1403](https://github.com/supabase/gotrue/issues/1403)) ([2d20729](https://github.com/supabase/gotrue/commit/2d207296ec22dd6c003c89626d255e35441fd52d)) +* only create or update the email / phone identity after it's been verified (again) ([#1409](https://github.com/supabase/gotrue/issues/1409)) ([bc6a5b8](https://github.com/supabase/gotrue/commit/bc6a5b884b43fe6b8cb924d3f79999fe5bfe7c5f)) +* unmarshal is_private_email correctly ([#1402](https://github.com/supabase/gotrue/issues/1402)) ([47df151](https://github.com/supabase/gotrue/commit/47df15113ce8d86666c0aba3854954c24fe39f7f)) +* use `pattern` for semver docker image tags ([#1411](https://github.com/supabase/gotrue/issues/1411)) ([14a3aeb](https://github.com/supabase/gotrue/commit/14a3aeb6c3f46c8d38d98cc840112dfd0278eeda)) + + +### Reverts + +* "fix: only create or update the email / phone identity after i… ([#1407](https://github.com/supabase/gotrue/issues/1407)) ([ff86849](https://github.com/supabase/gotrue/commit/ff868493169a0d9ac18b66058a735197b1df5b9b)) diff --git a/auth_v2.169.0/CODEOWNERS b/auth_v2.169.0/CODEOWNERS new file mode 100644 index 0000000..cb9b3ca --- /dev/null +++ b/auth_v2.169.0/CODEOWNERS @@ -0,0 +1 @@ +* @supabase/auth diff --git a/auth_v2.169.0/CODE_OF_CONDUCT.md b/auth_v2.169.0/CODE_OF_CONDUCT.md new file mode 
100644 index 0000000..2867c8a --- /dev/null +++ b/auth_v2.169.0/CODE_OF_CONDUCT.md @@ -0,0 +1,74 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, gender identity and expression, level of experience, +nationality, personal appearance, race, religion, or sexual identity and +orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at david@netlify.com. All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. 
+ +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at [http://contributor-covenant.org/version/1/4][version] + +[homepage]: http://contributor-covenant.org +[version]: http://contributor-covenant.org/version/1/4/ diff --git a/auth_v2.169.0/CONTRIBUTING.md b/auth_v2.169.0/CONTRIBUTING.md new file mode 100644 index 0000000..22a1add --- /dev/null +++ b/auth_v2.169.0/CONTRIBUTING.md @@ -0,0 +1,523 @@ +# CONTRIBUTING + +We would love to have contributions from each and every one of you in the community be it big or small and you are the ones who motivate us to do better than what we do today. + +## Code Of Conduct + +Please help us keep all our projects open and inclusive. Kindly follow our [Code of Conduct](CODE_OF_CONDUCT.md) to keep the ecosystem healthy and friendly for all. + +## Quick Start + +Auth has a development container setup that makes it easy to get started contributing. This setup only requires that [Docker](https://www.docker.com/get-started) is setup on your system. The development container setup includes a PostgreSQL container with migrations already applied and a container running GoTrue that will perform a hot reload when changes to the source code are detected. + +If you would like to run Auth locally or learn more about what these containers are doing for you, continue reading the [Setup and Tooling](#setup-and-tooling) section below. Otherwise, you can skip ahead to the [How To Verify that GoTrue is Available](#how-to-verify-that-auth-is-available) section to learn about working with and developing GoTrue. + +Before using the containers, you will need to make sure an `.env.docker` file exists by making a copy of `example.docker.env` and configuring it for your needs. The set of env vars in `example.docker.env` only contain the necessary env vars for auth to start in a docker environment. For the full list of env vars, please refer to `example.env` and copy over the necessary ones into your `.env.docker` file. + +The following are some basic commands. A full and up to date list of commands can be found in the project's `Makefile` or by running `make help`. + +### Starting the containers + +Start the containers as described above in an attached state with log output. + +```bash +make dev +``` + +### Running tests in the containers + +Start the containers with a fresh database and run the project's tests. + +```bash +make docker-test +``` + +### Removing the containers + +Remove both containers and their volumes. This removes any data associated with the containers. + +```bash +make docker-clean +``` + +### Rebuild the containers + +Fully rebuild the containers without using any cached layers. + +```bash +make docker-build +``` + +## Setup and Tooling + +Auth -- as the name implies -- is a user registration and authentication API developed in [Go](https://go.dev). + +It connects to a [PostgreSQL](https://www.postgresql.org) database in order to store authentication data, [Soda CLI](https://gobuffalo.io/en/docs/db/toolbox) to manage database schema and migrations, +and runs inside a [Docker](https://www.docker.com/get-started) container. + +Therefore, to contribute to Auth you will need to install these tools. 
+ +### Install Tools + +- Install [Go](https://go.dev) 1.22 + +```zsh +# Via Homebrew on macOS +brew install go@1.22 + +# Set the environment variable in the ~/.zshrc file +echo 'export PATH="/opt/homebrew/opt/go@1.22/bin:$PATH"' >> ~/.zshrc +``` + +- Install [Docker](https://www.docker.com/get-started) + +```zsh +# Via Homebrew on macOS +brew install docker +``` + +Or, if you prefer, download [Docker Desktop](https://www.docker.com/get-started). + +- Install [Soda CLI](https://gobuffalo.io/en/docs/db/toolbox) + +```zsh +# Via Homebrew on macOS +brew install gobuffalo/tap/pop +``` + +If you are on macOS Catalina you may [run into issues installing Soda with Brew](https://github.com/gobuffalo/homebrew-tap/issues/5). Do check your `GOPATH` and run + +`go build -o /bin/soda github.com/gobuffalo/pop/soda` to resolve. + +- Clone the Auth [repository](https://github.com/supabase/auth) + +```zsh +git clone https://github.com/supabase/auth +``` + +### Install Auth + +To begin installation, be sure to start from the root directory. + +- `cd auth` + +To complete installation, you will: + +- Install the PostgreSQL Docker image +- Create the DB Schema and Migrations +- Setup a local `.env` for environment variables +- Compile Auth +- Run the Auth binary executable + +#### Installation Steps + +1. Start Docker +2. To install the PostgreSQL Docker image, run: + +```zsh +# Builds the postgres image +docker-compose -f docker-compose-dev.yml build postgres + +# Runs the postgres container +docker-compose -f docker-compose-dev.yml up postgres +``` + +You should then see in Docker that `auth_postgresql` is running on `port: 5432`. + +> **Important** If you happen to already have a local running instance of Postgres running on the port `5432` because you +> may have installed via [homebrew on macOS](https://formulae.brew.sh/formula/postgresql) then be certain to stop the process using: +> +> - `brew services stop postgresql` +> +> If you need to run the test environment on another port, you will need to modify several configuration files to use a different custom port. + +3. Next compile the Auth binary: + +When you fork a repository, GitHub does not automatically copy all the tags (tags are not included by default). To ensure the correct tag is set before building the binary, you need to fetch the tags from the upstream repository and push them to your fork. Follow these steps: + +```zsh +# Fetch the tags from the upstream repository +git fetch upstream --tags + +# Push the tags to your fork +git push origin --tags +``` + +Then build the binary by running: + +```zsh +make build +``` + +4. 
To setup the database schema via Soda, run: + +```zsh +make migrate_test +``` + +You should see log messages that indicate that the Auth migrations were applied successfully: + +```terminal +INFO[0000] Auth migrations applied successfully +DEBU[0000] after status +[POP] 2021/12/15 10:44:36 sql - SELECT EXISTS (SELECT schema_migrations.* FROM schema_migrations AS schema_migrations WHERE version = $1) | ["20210710035447"] +[POP] 2021/12/15 10:44:36 sql - SELECT EXISTS (SELECT schema_migrations.* FROM schema_migrations AS schema_migrations WHERE version = $1) | ["20210722035447"] +[POP] 2021/12/15 10:44:36 sql - SELECT EXISTS (SELECT schema_migrations.* FROM schema_migrations AS schema_migrations WHERE version = $1) | ["20210730183235"] +[POP] 2021/12/15 10:44:36 sql - SELECT EXISTS (SELECT schema_migrations.* FROM schema_migrations AS schema_migrations WHERE version = $1) | ["20210909172000"] +[POP] 2021/12/15 10:44:36 sql - SELECT EXISTS (SELECT schema_migrations.* FROM schema_migrations AS schema_migrations WHERE version = $1) | ["20211122151130"] +Version Name Status +20210710035447 alter_users Applied +20210722035447 adds_confirmed_at Applied +20210730183235 add_email_change_confirmed Applied +20210909172000 create_identities_table Applied +20211122151130 create_user_id_idx Applied +``` + +That lists each migration that was applied. Note: there may be more migrations than those listed. + +4. Create a `.env` file in the root of the project and copy the following config in [example.env](example.env). Set the values to GOTRUE_SMS_TEST_OTP_VALID_UNTIL in the `.env` file. + +5. In order to have Auth connect to your PostgreSQL database running in Docker, it is important to set a connection string like: + +``` +DATABASE_URL="postgres://supabase_auth_admin:root@localhost:5432/postgres" +``` + +> Important: Auth requires a set of SMTP credentials to run, you can generate your own SMTP credentials via an SMTP provider such as AWS SES, SendGrid, MailChimp, SendInBlue or any other SMTP providers. + +6. Then finally Start Auth +7. Verify that Auth is Available + +### Starting Auth + +Start Auth by running the executable: + +```zsh +./auth +``` + +This command will re-run migrations and then indicate that Auth has started: + +```zsh +INFO[0000] Auth API started on: localhost:9999 +``` + +### How To Verify that Auth is Available + +To test that your Auth is up and available, you can query the `health` endpoint at `http://localhost:9999/health`. You should see a response similar to: + +```json +{ + "description": "Auth is a user registration and authentication API", + "name": "Auth", + "version": "" +} +``` + +To see the current settings, make a request to `http://localhost:9999/settings` and you should see a response similar to: + +```json +{ + "external": { + "apple": false, + "azure": false, + "bitbucket": false, + "discord": false, + "github": false, + "gitlab": false, + "google": false, + "facebook": false, + "spotify": false, + "slack": false, + "slack_oidc": false, + "twitch": true, + "twitter": false, + "email": true, + "phone": false, + "saml": false + }, + "external_labels": { + "saml": "auth0" + }, + "disable_signup": false, + "mailer_autoconfirm": false, + "phone_autoconfirm": false, + "sms_provider": "twilio" +} +``` + +## How to Use Admin API Endpoints + +To test the admin endpoints (or other api endpoints), you can invoke via HTTP requests. Using [Insomnia](https://insomnia.rest/products/insomnia) can help you issue these requests. 
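+If you prefer to script these calls instead of using a GUI client, the sketch below shows one way to do it in Go. It is an illustration only, not part of this repository: it assumes the third-party `github.com/golang-jwt/jwt/v5` library, a local instance on `localhost:9999`, and the `GOTRUE_JWT_SECRET` / admin-role requirements described in the next few paragraphs.
+
+```go
+package main
+
+import (
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+
+	"github.com/golang-jwt/jwt/v5"
+)
+
+func main() {
+	// Sign an HS256 JWT with an admin role claim. The secret must match
+	// GOTRUE_JWT_SECRET and the role must be accepted by GOTRUE_JWT_ADMIN_ROLES.
+	secret := os.Getenv("GOTRUE_JWT_SECRET")
+	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"role": "supabase_admin"})
+	signed, err := token.SignedString([]byte(secret))
+	if err != nil {
+		panic(err)
+	}
+
+	// Use the signed token as a Bearer token against an admin endpoint.
+	req, err := http.NewRequest(http.MethodGet, "http://localhost:9999/admin/users", nil)
+	if err != nil {
+		panic(err)
+	}
+	req.Header.Set("Authorization", "Bearer "+signed)
+
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		panic(err)
+	}
+	defer resp.Body.Close()
+
+	body, _ := io.ReadAll(resp.Body)
+	fmt.Println(resp.Status, string(body))
+}
+```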
+ +You will need to know the `GOTRUE_JWT_SECRET` configured in the `.env` settings. + +Also, you must generate a JWT with the signature which has the `supabase_admin` role (or one that is specified in `GOTRUE_JWT_ADMIN_ROLES`). + +For example: + +```json +{ + "role": "supabase_admin" +} +``` + +You can sign this payload using the [JWT.io Debugger](https://jwt.io/#debugger-io) but make sure that `secret base64 encoded` is unchecked. + +Then you can use this JWT as a Bearer token for admin requests. + +### Create User (aka Sign Up a User) + +To create a new user, `POST /admin/users` with the payload: + +```json +{ + "email": "user@example.com", + "password": "12345678" +} +``` + +#### Request + +``` +POST /admin/users HTTP/1.1 +Host: localhost:9999 +User-Agent: insomnia/2021.7.2 +Content-Type: application/json +Authorization: Bearer +Accept: */* +Content-Length: 57 +``` + +#### Response + +And you should get a new user: + +```json +{ + "id": "e78c512d-68e4-482b-901b-75003e89acae", + "aud": "authenticated", + "role": "authenticated", + "email": "user@example.com", + "phone": "", + "app_metadata": { + "provider": "email", + "providers": ["email"] + }, + "user_metadata": {}, + "identities": null, + "created_at": "2021-12-15T12:40:03.507551-05:00", + "updated_at": "2021-12-15T12:40:03.512067-05:00" +} +``` + +### List/Find Users + +To create a new user, make a request to `GET /admin/users`. + +#### Request + +``` +GET /admin/users HTTP/1.1 +Host: localhost:9999 +User-Agent: insomnia/2021.7.2 +Authorization: Bearer +Accept: */\_ +``` + +#### Response + +The response from `/admin/users` should return all users: + +```json +{ + "aud": "authenticated", + "users": [ + { + "id": "b7fd0253-6e16-4d4e-b61b-5943cb1b2102", + "aud": "authenticated", + "role": "authenticated", + "email": "user+4@example.com", + "phone": "", + "app_metadata": { + "provider": "email", + "providers": ["email"] + }, + "user_metadata": {}, + "identities": null, + "created_at": "2021-12-15T12:43:58.12207-05:00", + "updated_at": "2021-12-15T12:43:58.122073-05:00" + }, + { + "id": "d69ae847-99be-4642-868f-439c2cdd9af4", + "aud": "authenticated", + "role": "authenticated", + "email": "user+3@example.com", + "phone": "", + "app_metadata": { + "provider": "email", + "providers": ["email"] + }, + "user_metadata": {}, + "identities": null, + "created_at": "2021-12-15T12:43:56.730209-05:00", + "updated_at": "2021-12-15T12:43:56.730213-05:00" + }, + { + "id": "7282cf42-344e-4474-bdf6-d48e4968a2e4", + "aud": "authenticated", + "role": "authenticated", + "email": "user+2@example.com", + "phone": "", + "app_metadata": { + "provider": "email", + "providers": ["email"] + }, + "user_metadata": {}, + "identities": null, + "created_at": "2021-12-15T12:43:54.867676-05:00", + "updated_at": "2021-12-15T12:43:54.867679-05:00" + }, + { + "id": "e78c512d-68e4-482b-901b-75003e89acae", + "aud": "authenticated", + "role": "authenticated", + "email": "user@example.com", + "phone": "", + "app_metadata": { + "provider": "email", + "providers": ["email"] + }, + "user_metadata": {}, + "identities": null, + "created_at": "2021-12-15T12:40:03.507551-05:00", + "updated_at": "2021-12-15T12:40:03.507554-05:00" + } + ] +} +``` + +### Running Database Migrations + +If you need to run any new migrations: + +```zsh +make migrate_test +``` + +## Testing + +Currently, we don't use a separate test database, so the same database created when installing Auth to run locally is used. 
+ +The following commands should help in setting up a database and running the tests: + +```sh +# Runs the database in a docker container +$ docker-compose -f docker-compose-dev.yml up postgres + +# Applies the migrations to the database (requires soda cli) +$ make migrate_test + +# Executes the tests +$ make test +``` + +### Customizing the PostgreSQL Port + +if you already run PostgreSQL and need to run your database on a different, custom port, +you will need to make several configuration changes to the following files: + +In these examples, we change the port from 5432 to 7432. + +> Note: This is not recommended, but if you do, please do not check in changes. + +``` +// file: docker-compose-dev.yml +ports: + - 7432:5432 \ 👈 set the first value to your external facing port +``` + +The port you customize here can them be used in the subsequent configuration: + +``` +// file: database.yaml +test: +dialect: "postgres" +database: "postgres" +host: {{ envOr "POSTGRES_HOST" "127.0.0.1" }} +port: {{ envOr "POSTGRES_PORT" "7432" }} 👈 set to your port +``` + +``` +// file: test.env +DATABASE_URL="postgres://supabase_auth_admin:root@localhost:7432/postgres" 👈 set to your port +``` + +``` +// file: migrate.sh +export GOTRUE_DB_DATABASE_URL="postgres://supabase_auth_admin:root@localhost:7432/$DB_ENV" +``` + +## Helpful Docker Commands + +``` +// file: docker-compose-dev.yml +container_name: auth_postgres +``` + +```zsh +# Command line into bash on the PostgreSQL container +docker exec -it auth_postgres bash + +# Removes Container +docker container rm -f auth_postgres + +# Removes volume +docker volume rm postgres_data +``` + +## Updating Package Dependencies + +- `make deps` +- `go mod tidy` if necessary + +## Submitting Pull Requests + +We actively welcome your pull requests. + +- Fork the repo and create your branch from `master`. +- If you've added code that should be tested, add tests. +- If you've changed APIs, update the documentation. +- Ensure the test suite passes. +- Make sure your code lints. + +### Checklist for Submitting Pull Requests + +- Is there a corresponding issue created for it? If so, please include it in the PR description so we can track / refer to it. +- Does your PR follow the [semantic-release commit guidelines](https://github.com/angular/angular.js/blob/master/DEVELOPERS.md#-git-commit-guidelines)? +- If the PR is a `feat`, an [RFC](https://github.com/supabase/rfcs) or a detailed description of the design implementation is required. The former (RFC) is preferred before starting on the PR. +- Are the existing tests passing? +- Have you written some tests for your PR? + +## Guidelines for Implementing Additional OAuth Providers + +> ⚠️ We won't be accepting any additional oauth / sms provider contributions for now because we intend to support these through webhooks or a generic provider in the future. + +Please ensure that an end-to-end test is done for the OAuth provider implemented. 
+ +An end-to-end test includes: + +- Creating an application on the oauth provider site +- Generating your own client_id and secret +- Testing that `http://localhost:9999/authorize?provider=MY_COOL_NEW_PROVIDER` redirects you to the provider sign-in page +- The callback is handled properly +- Gotrue redirects to the `SITE_URL` or one of the URI's specified in the `URI_ALLOW_LIST` with the access_token, provider_token, expiry and refresh_token as query fragments + +### Writing tests for the new OAuth provider implemented + +Since implementing an additional OAuth provider consists of making api calls to an external api, we set up a mock server to attempt to mock the responses expected from the OAuth provider. + +## License + +By contributing to Auth, you agree that your contributions will be licensed +under its [MIT license](LICENSE). diff --git a/auth_v2.169.0/Dockerfile b/auth_v2.169.0/Dockerfile new file mode 100644 index 0000000..17d5071 --- /dev/null +++ b/auth_v2.169.0/Dockerfile @@ -0,0 +1,32 @@ +FROM golang:1.22.3-alpine3.20 as build +ENV GO111MODULE=on +ENV CGO_ENABLED=0 +ENV GOOS=linux + +RUN apk add --no-cache make git + +WORKDIR /go/src/github.com/supabase/auth + +# Pulling dependencies +COPY ./Makefile ./go.* ./ +RUN make deps + +# Building stuff +COPY . /go/src/github.com/supabase/auth + +# Make sure you change the RELEASE_VERSION value before publishing an image. +RUN RELEASE_VERSION=unspecified make build + +# Always use alpine:3 so the latest version is used. This will keep CA certs more up to date. +FROM alpine:3 +RUN adduser -D -u 1000 supabase + +RUN apk add --no-cache ca-certificates +COPY --from=build /go/src/github.com/supabase/auth/auth /usr/local/bin/auth +COPY --from=build /go/src/github.com/supabase/auth/migrations /usr/local/etc/auth/migrations/ +RUN ln -s /usr/local/bin/auth /usr/local/bin/gotrue + +ENV GOTRUE_DB_MIGRATIONS_PATH /usr/local/etc/auth/migrations + +USER supabase +CMD ["auth"] diff --git a/auth_v2.169.0/Dockerfile.dev b/auth_v2.169.0/Dockerfile.dev new file mode 100644 index 0000000..d2733aa --- /dev/null +++ b/auth_v2.169.0/Dockerfile.dev @@ -0,0 +1,18 @@ +FROM golang:1.22.3-alpine3.20 +ENV GO111MODULE=on +ENV CGO_ENABLED=0 +ENV GOOS=linux + +RUN apk add --no-cache make git bash + +WORKDIR /go/src/github.com/supabase/auth + +# Pulling dependencies +COPY ./Makefile ./go.* ./ + +# Production dependencies +RUN make deps + +# Development dependences +RUN go get github.com/githubnemo/CompileDaemon +RUN go install github.com/githubnemo/CompileDaemon diff --git a/auth_v2.169.0/Dockerfile.postgres.dev b/auth_v2.169.0/Dockerfile.postgres.dev new file mode 100644 index 0000000..58661ef --- /dev/null +++ b/auth_v2.169.0/Dockerfile.postgres.dev @@ -0,0 +1,8 @@ +FROM postgres:15 +WORKDIR / +RUN pwd +COPY init_postgres.sh /docker-entrypoint-initdb.d/init.sh +RUN chmod +x /docker-entrypoint-initdb.d/init.sh +EXPOSE 5432 + +CMD ["postgres"] diff --git a/auth_v2.169.0/LICENSE b/auth_v2.169.0/LICENSE new file mode 100644 index 0000000..8a7d702 --- /dev/null +++ b/auth_v2.169.0/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Supabase + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following 
conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/auth_v2.169.0/Makefile b/auth_v2.169.0/Makefile new file mode 100644 index 0000000..7bfd8ab --- /dev/null +++ b/auth_v2.169.0/Makefile @@ -0,0 +1,93 @@ +.PHONY: all build deps dev-deps image migrate test vet sec format unused +CHECK_FILES?=./... + +FLAGS=-ldflags "-X github.com/supabase/auth/internal/utilities.Version=`git describe --tags`" -buildvcs=false +ifdef RELEASE_VERSION + FLAGS=-ldflags "-X github.com/supabase/auth/internal/utilities.Version=v$(RELEASE_VERSION)" -buildvcs=false +endif + +ifneq ($(shell docker compose version 2>/dev/null),) + DOCKER_COMPOSE=docker compose +else + DOCKER_COMPOSE=docker-compose +endif + +DEV_DOCKER_COMPOSE:=docker-compose-dev.yml + +help: ## Show this help. + @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {sub("\\\\n",sprintf("\n%22c"," "), $$2);printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST) + +all: vet sec static build ## Run the tests and build the binary. + +build: deps ## Build the binary. + CGO_ENABLED=0 go build $(FLAGS) + CGO_ENABLED=0 GOOS=linux GOARCH=arm64 go build $(FLAGS) -o auth-arm64 + +dev-deps: ## Install developer dependencies + @go install github.com/gobuffalo/pop/soda@latest + @go install github.com/securego/gosec/v2/cmd/gosec@latest + @go install honnef.co/go/tools/cmd/staticcheck@latest + @go install github.com/deepmap/oapi-codegen/cmd/oapi-codegen@latest + @go install github.com/nishanths/exhaustive/cmd/exhaustive@latest + +deps: ## Install dependencies. + @go mod download + @go mod verify + +migrate_dev: ## Run database migrations for development. + hack/migrate.sh postgres + +migrate_test: ## Run database migrations for test. + hack/migrate.sh postgres + +test: build ## Run tests. + go test $(CHECK_FILES) -coverprofile=coverage.out -coverpkg ./... -p 1 -race -v -count=1 + ./hack/coverage.sh + +vet: # Vet the code + go vet $(CHECK_FILES) + +sec: dev-deps # Check for security vulnerabilities + gosec -quiet -exclude-generated $(CHECK_FILES) + gosec -quiet -tests -exclude-generated -exclude=G104 $(CHECK_FILES) + +unused: dev-deps # Look for unused code + @echo "Unused code:" + staticcheck -checks U1000 $(CHECK_FILES) + + @echo + + @echo "Code used only in _test.go (do move it in those files):" + staticcheck -checks U1000 -tests=false $(CHECK_FILES) + +static: dev-deps + staticcheck ./... + exhaustive ./... + +generate: dev-deps + go generate ./... 
+ +dev: ## Run the development containers + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) up + +down: ## Shutdown the development containers + # Start postgres first and apply migrations + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) down + +docker-test: ## Run the tests using the development containers + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) up -d postgres + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) run auth sh -c "make migrate_test" + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) run auth sh -c "make test" + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) down -v + +docker-build: ## Force a full rebuild of the development containers + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) build --no-cache + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) up -d postgres + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) run auth sh -c "make migrate_dev" + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) down + +docker-clean: ## Remove the development containers and volumes + ${DOCKER_COMPOSE} -f $(DEV_DOCKER_COMPOSE) rm -fsv + +format: + gofmt -s -w . diff --git a/auth_v2.169.0/README.md b/auth_v2.169.0/README.md new file mode 100644 index 0000000..5b9c702 --- /dev/null +++ b/auth_v2.169.0/README.md @@ -0,0 +1,1229 @@ +# Auth - Authentication and User Management by Supabase + +[![Coverage Status](https://coveralls.io/repos/github/supabase/auth/badge.svg?branch=master)](https://coveralls.io/github/supabase/auth?branch=master) + +Auth is a user management and authentication server written in Go that powers +[Supabase](https://supabase.com)'s features such as: + +- Issuing JWTs +- Row Level Security with PostgREST +- User management +- Sign in with email, password, magic link, phone number +- Sign in with external providers (Google, Apple, Facebook, Discord, ...) + +It is originally based on the excellent +[GoTrue codebase by Netlify](https://github.com/netlify/gotrue), however both have diverged significantly in features and capabilities. + +If you wish to contribute to the project, please refer to the [contributing guide](/CONTRIBUTING.md). + +## Table Of Contents + +- [Quick Start](#quick-start) +- [Running in Production](#running-in-production) +- [Configuration](#configuration) +- [Endpoints](#endpoints) + +## Quick Start + +Create a `.env` file to store your own custom env vars. See [`example.env`](example.env) + +1. Start the local postgres database in a postgres container: `docker-compose -f docker-compose-dev.yml up postgres` +2. Build the auth binary: `make build` . You should see an output like this: + +```bash +go build -ldflags "-X github.com/supabase/auth/cmd.Version=`git rev-parse HEAD`" +GOOS=linux GOARCH=arm64 go build -ldflags "-X github.com/supabase/auth/cmd.Version=`git rev-parse HEAD`" -o gotrue-arm64 +``` + +3. Execute the auth binary: `./auth` + +### If you have docker installed + +Create a `.env.docker` file to store your own custom env vars. See [`example.docker.env`](example.docker.env) + +1. `make build` +2. `make dev` +3. `docker ps` should show 2 docker containers (`auth_postgresql` and `gotrue_gotrue`) +4. That's it! Visit the [health checkendpoint](http://localhost:9999/health) to confirm that auth is running. + +## Running in production + +Running an authentication server in production is not an easy feat. We +recommend using [Supabase Auth](https://supabase.com/auth) which gets regular +security updates. + +Otherwise, please make sure you setup a process to promptly update to the +latest version. 
You can do that by following this repository, specifically the +[Releases](https://github.com/supabase/auth/releases) and [Security +Advisories](https://github.com/supabase/auth/security/advisories) sections. + +### Backward compatibility + +Auth uses the [Semantic Versioning](https://semver.org) scheme. Here are some +further clarifications on backward compatibility guarantees: + +**Go API compatibility** + +Auth is not meant to be used as a Go library. There are no guarantees on +backward API compatibility when used this way regardless which version number +changes. + +**Patch** + +Changes to the patch version guarantees backward compatibility with: + +- Database objects (tables, columns, indexes, functions). +- REST API +- JWT structure +- Configuration + +Guaranteed examples: + +- A column won't change its type. +- A table won't change its primary key. +- An index will not be removed. +- A uniqueness constraint will not be removed. +- A REST API will not be removed. +- Parameters to REST APIs will work equivalently as before (or better, if a bug + has been fixed). +- Configuration will not change. + +Not guaranteed examples: + +- A table may add new columns. +- Columns in a table may be reordered. +- Non-unique constraints may be removed (database level checks, null, default + values). +- JWT may add new properties. + +**Minor** + +Changes to minor version guarantees backward compatibility with: + +- REST API +- JWT structure +- Configuration + +Exceptions to these guarantees will be made only when serious security issues +are found that can't be remedied in any other way. + +Guaranteed examples: + +- Existing APIs may be deprecated but continue working for the next few minor + version releases. +- Configuration changes may become deprecated but continue working for the next + few minor version releases. +- Already issued JWTs will be accepted, but new JWTs may be with a different + structure (but usually similar). + +Not guaranteed examples: + +- Removal of JWT fields after a deprecation notice. +- Removal of certain APIs after a deprecation notice. +- Removal of sign-in with external providers, after a deprecation notice. +- Deletion, truncation, significant schema changes to tables, indexes, views, + functions. + +We aim to provide a deprecation notice in execution logs for at least two major +version releases or two weeks if multiple releases go out. Compatibility will +be guaranteed while the notice is live. + +**Major** + +Changes to the major version do not guarantee any backward compatibility with +previous versions. + +### Inherited features + +Certain inherited features from the Netlify codebase are not supported by +Supabase and they may be removed without prior notice in the future. This is a +comprehensive list of those features: + +1. Multi-tenancy via the `instances` table i.e. `GOTRUE_MULTI_INSTANCE_MODE` + configuration parameter. +2. System user (zero UUID user). +3. Super admin via the `is_super_admin` column. +4. Group information in JWTs via `GOTRUE_JWT_ADMIN_GROUP_NAME` and other + configuration fields. +5. Symmetrics JWTs. In the future it is very likely that Auth will begin + issuing asymmetric JWTs (subject to configuration), so do not rely on the + assumption that only HS256 signed JWTs will be issued long term. + +Note that this is not an exhaustive list and it may change. + +### Best practices when self-hosting + +These are some best practices to follow when self-hosting to ensure backward +compatibility with Auth: + +1. 
Do not modify the schema managed by Auth. You can see all of the + migrations in the `migrations` directory. +2. Do not rely on schema and structure of data in the database. Always use + Auth APIs and JWTs to infer information about users. +3. Always run Auth behind a TLS-capable proxy such as a load balancer, CDN, + nginx or other similar software. + +## Configuration + +You may configure Auth using either a configuration file named `.env`, +environment variables, or a combination of both. Environment variables are prefixed with `GOTRUE_`, and will always have precedence over values provided via file. + +### Top-Level + +```properties +GOTRUE_SITE_URL=https://example.netlify.com/ +``` + +`SITE_URL` - `string` **required** + +The base URL your site is located at. Currently used in combination with other settings to construct URLs used in emails. Any URI that shares a host with `SITE_URL` is a permitted value for `redirect_to` params (see `/authorize` etc.). + +`URI_ALLOW_LIST` - `string` + +A comma separated list of URIs (e.g. `"https://foo.example.com,https://*.foo.example.com,https://bar.example.com"`) which are permitted as valid `redirect_to` destinations. Defaults to []. Supports wildcard matching through globbing. e.g. `https://*.foo.example.com` will allow `https://a.foo.example.com` and `https://b.foo.example.com` to be accepted. Globbing is also supported on subdomains. e.g. `https://foo.example.com/*` will allow `https://foo.example.com/page1` and `https://foo.example.com/page2` to be accepted. + +For more common glob patterns, check out the [following link](https://pkg.go.dev/github.com/gobwas/glob#Compile). + +`OPERATOR_TOKEN` - `string` _Multi-instance mode only_ + +The shared secret with an operator (usually Netlify) for this microservice. Used to verify requests have been proxied through the operator and +the payload values can be trusted. + +`DISABLE_SIGNUP` - `bool` + +When signup is disabled the only way to create new users is through invites. Defaults to `false`, all signups enabled. + +`GOTRUE_EXTERNAL_EMAIL_ENABLED` - `bool` + +Use this to disable email signups (users can still use external oauth providers to sign up / sign in) + +`GOTRUE_EXTERNAL_PHONE_ENABLED` - `bool` + +Use this to disable phone signups (users can still use external oauth providers to sign up / sign in) + +`GOTRUE_RATE_LIMIT_HEADER` - `string` + +Header on which to rate limit the `/token` endpoint. + +`GOTRUE_RATE_LIMIT_EMAIL_SENT` - `string` + +Rate limit the number of emails sent per hr on the following endpoints: `/signup`, `/invite`, `/magiclink`, `/recover`, `/otp`, & `/user`. + +`GOTRUE_PASSWORD_MIN_LENGTH` - `int` + +Minimum password length, defaults to 6. + +`GOTRUE_PASSWORD_REQUIRED_CHARACTERS` - a string of character sets separated by `:`. A password must contain at least one character of each set to be accepted. To use the `:` character escape it with `\`. + +`GOTRUE_SECURITY_REFRESH_TOKEN_ROTATION_ENABLED` - `bool` + +If refresh token rotation is enabled, auth will automatically detect malicious attempts to reuse a revoked refresh token. When a malicious attempt is detected, gotrue immediately revokes all tokens that descended from the offending token. + +`GOTRUE_SECURITY_REFRESH_TOKEN_REUSE_INTERVAL` - `string` + +This setting is only applicable if `GOTRUE_SECURITY_REFRESH_TOKEN_ROTATION_ENABLED` is enabled. The reuse interval for a refresh token allows for exchanging the refresh token multiple times during the interval to support concurrency or offline issues. 
During the reuse interval, auth will not consider using a revoked token as a malicious attempt and will simply return the child refresh token. + +Only the previous revoked token can be reused. Using an old refresh token way before the current valid refresh token will trigger the reuse detection. + +### API + +```properties +GOTRUE_API_HOST=localhost +PORT=9999 +API_EXTERNAL_URL=http://localhost:9999 +``` + +`API_HOST` - `string` + +Hostname to listen on. + +`PORT` (no prefix) / `API_PORT` - `number` + +Port number to listen on. Defaults to `8081`. + +`API_ENDPOINT` - `string` _Multi-instance mode only_ + +Controls what endpoint Netlify can access this API on. + +`API_EXTERNAL_URL` - `string` **required** + +The URL on which Gotrue might be accessed at. + +`REQUEST_ID_HEADER` - `string` + +If you wish to inherit a request ID from the incoming request, specify the name in this value. + +### Database + +```properties +GOTRUE_DB_DRIVER=postgres +DATABASE_URL=root@localhost/auth +``` + +`DB_DRIVER` - `string` **required** + +Chooses what dialect of database you want. Must be `postgres`. + +`DATABASE_URL` (no prefix) / `DB_DATABASE_URL` - `string` **required** + +Connection string for the database. + +`GOTRUE_DB_MAX_POOL_SIZE` - `int` + +Sets the maximum number of open connections to the database. Defaults to 0 which is equivalent to an "unlimited" number of connections. + +`DB_NAMESPACE` - `string` + +Adds a prefix to all table names. + +**Migrations Note** + +Migrations are applied automatically when you run `./auth`. However, you also have the option to rerun the migrations via the following methods: + +- If built locally: `./auth migrate` +- Using Docker: `docker run --rm auth gotrue migrate` + +### Logging + +```properties +LOG_LEVEL=debug # available without GOTRUE prefix (exception) +GOTRUE_LOG_FILE=/var/log/go/auth.log +``` + +`LOG_LEVEL` - `string` + +Controls what log levels are output. Choose from `panic`, `fatal`, `error`, `warn`, `info`, or `debug`. Defaults to `info`. + +`LOG_FILE` - `string` + +If you wish logs to be written to a file, set `log_file` to a valid file path. + +### Observability + +Auth has basic observability built in. It is able to export +[OpenTelemetry](https://opentelemetry.io) metrics and traces to a collector. + +#### Tracing + +To enable tracing configure these variables: + +`GOTRUE_TRACING_ENABLED` - `boolean` + +`GOTRUE_TRACING_EXPORTER` - `string` only `opentelemetry` supported + +Make sure you also configure the [OpenTelemetry +Exporter](https://opentelemetry.io/docs/reference/specification/protocol/exporter/) +configuration for your collector or service. + +For example, if you use +[Honeycomb.io](https://docs.honeycomb.io/getting-data-in/opentelemetry/go-distro/#using-opentelemetry-without-the-honeycomb-distribution) +you should set these standard OpenTelemetry OTLP variables: + +``` +OTEL_SERVICE_NAME=auth +OTEL_EXPORTER_OTLP_PROTOCOL=grpc +OTEL_EXPORTER_OTLP_ENDPOINT=https://api.honeycomb.io:443 +OTEL_EXPORTER_OTLP_HEADERS="x-honeycomb-team=,x-honeycomb-dataset=auth" +``` + +#### Metrics + +To enable metrics configure these variables: + +`GOTRUE_METRICS_ENABLED` - `boolean` + +`GOTRUE_METRICS_EXPORTER` - `string` only `opentelemetry` and `prometheus` +supported + +Make sure you also configure the [OpenTelemetry +Exporter](https://opentelemetry.io/docs/reference/specification/protocol/exporter/) +configuration for your collector or service. 
+ +If you use the `prometheus` exporter, the server host and port can be +configured using these standard OpenTelemetry variables: + +`OTEL_EXPORTER_PROMETHEUS_HOST` - IP address, default `0.0.0.0` + +`OTEL_EXPORTER_PROMETHEUS_PORT` - port number, default `9100` + +The metrics are exported on the `/` path on the server. + +If you use the `opentelemetry` exporter, the metrics are pushed to the +collector. + +For example, if you use +[Honeycomb.io](https://docs.honeycomb.io/getting-data-in/opentelemetry/go-distro/#using-opentelemetry-without-the-honeycomb-distribution) +you should set these standard OpenTelemetry OTLP variables: + +``` +OTEL_SERVICE_NAME=auth +OTEL_EXPORTER_OTLP_PROTOCOL=grpc +OTEL_EXPORTER_OTLP_ENDPOINT=https://api.honeycomb.io:443 +OTEL_EXPORTER_OTLP_HEADERS="x-honeycomb-team=,x-honeycomb-dataset=auth" +``` + +Note that Honeycomb.io requires a paid plan to ingest metrics. + +If you need to debug an issue with traces or metrics not being pushed, you can +set `DEBUG=true` to get more insights from the OpenTelemetry SDK. + +#### Custom resource attributes + +When using the OpenTelemetry tracing or metrics exporter you can define custom +resource attributes using the [standard `OTEL_RESOURCE_ATTRIBUTES` environment +variable](https://opentelemetry.io/docs/reference/specification/resource/sdk/#specifying-resource-information-via-an-environment-variable). + +A default attribute `auth.version` is provided containing the build version. + +#### Tracing HTTP routes + +All HTTP calls to the Auth API are traced. Routes use the parametrized +version of the route, and the values for the route parameters can be found as +the `http.route.params.` span attribute. + +For example, the following request: + +``` +GET /admin/users/4acde936-82dc-4552-b851-831fb8ce0927/ +``` + +will be traced as: + +``` +http.method = GET +http.route = /admin/users/{user_id} +http.route.params.user_id = 4acde936-82dc-4552-b851-831fb8ce0927 +``` + +#### Go runtime and HTTP metrics + +All of the Go runtime metrics are exposed. Some HTTP metrics are also collected +by default. + +### JSON Web Tokens (JWT) + +```properties +GOTRUE_JWT_SECRET=supersecretvalue +GOTRUE_JWT_EXP=3600 +GOTRUE_JWT_AUD=netlify +``` + +`JWT_SECRET` - `string` **required** + +The secret used to sign JWT tokens with. + +`JWT_EXP` - `number` + +How long tokens are valid for, in seconds. Defaults to 3600 (1 hour). + +`JWT_AUD` - `string` + +The default JWT audience. Use audiences to group users. + +`JWT_ADMIN_GROUP_NAME` - `string` + +The name of the admin group (if enabled). Defaults to `admin`. + +`JWT_DEFAULT_GROUP_NAME` - `string` + +The default group to assign all new users to. + +### External Authentication Providers + +We support `apple`, `azure`, `bitbucket`, `discord`, `facebook`, `figma`, `github`, `gitlab`, `google`, `keycloak`, `linkedin`, `notion`, `spotify`, `slack`, `twitch`, `twitter` and `workos` for external authentication. + +Use the names as the keys underneath `external` to configure each separately. + +```properties +GOTRUE_EXTERNAL_GITHUB_ENABLED=true +GOTRUE_EXTERNAL_GITHUB_CLIENT_ID=myappclientid +GOTRUE_EXTERNAL_GITHUB_SECRET=clientsecretvaluessssh +GOTRUE_EXTERNAL_GITHUB_REDIRECT_URI=http://localhost:3000/callback +``` + +No external providers are required, but you must provide the required values if you choose to enable any. + +`EXTERNAL_X_ENABLED` - `bool` + +Whether this external provider is enabled or not + +`EXTERNAL_X_CLIENT_ID` - `string` **required** + +The OAuth2 Client ID registered with the external provider. 
+ +`EXTERNAL_X_SECRET` - `string` **required** + +The OAuth2 Client Secret provided by the external provider when you registered. + +`EXTERNAL_X_REDIRECT_URI` - `string` **required** + +The URI a OAuth2 provider will redirect to with the `code` and `state` values. + +`EXTERNAL_X_URL` - `string` + +The base URL used for constructing the URLs to request authorization and access tokens. Used by `gitlab` and `keycloak`. For `gitlab` it defaults to `https://gitlab.com`. For `keycloak` you need to set this to your instance, for example: `https://keycloak.example.com/realms/myrealm` + +#### Apple OAuth + +To try out external authentication with Apple locally, you will need to do the following: + +1. Remap localhost to \ in your `/etc/hosts` config. +2. Configure auth to serve HTTPS traffic over localhost by replacing `ListenAndServe` in [api.go](internal/api/api.go) with: + + ``` + func (a *API) ListenAndServe(hostAndPort string) { + log := logrus.WithField("component", "api") + path, err := os.Getwd() + if err != nil { + log.Println(err) + } + server := &http.Server{ + Addr: hostAndPort, + Handler: a.handler, + } + done := make(chan struct{}) + defer close(done) + go func() { + waitForTermination(log, done) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + server.Shutdown(ctx) + }() + if err := server.ListenAndServeTLS("PATH_TO_CRT_FILE", "PATH_TO_KEY_FILE"); err != http.ErrServerClosed { + log.WithError(err).Fatal("http server listen failed") + } + } + ``` + +3. Generate the crt and key file. See [here](https://www.freecodecamp.org/news/how-to-get-https-working-on-your-local-development-environment-in-5-minutes-7af615770eec/) for more information. +4. Generate the `GOTRUE_EXTERNAL_APPLE_SECRET` by following this [post](https://medium.com/identity-beyond-borders/how-to-configure-sign-in-with-apple-77c61e336003)! + +### E-Mail + +Sending email is not required, but highly recommended for password recovery. +If enabled, you must provide the required values below. + +```properties +GOTRUE_SMTP_HOST=smtp.mandrillapp.com +GOTRUE_SMTP_PORT=587 +GOTRUE_SMTP_USER=smtp-delivery@example.com +GOTRUE_SMTP_PASS=correcthorsebatterystaple +GOTRUE_SMTP_ADMIN_EMAIL=support@example.com +GOTRUE_MAILER_SUBJECTS_CONFIRMATION="Please confirm" +``` + +`SMTP_ADMIN_EMAIL` - `string` **required** + +The `From` email address for all emails sent. + +`SMTP_HOST` - `string` **required** + +The mail server hostname to send emails through. + +`SMTP_PORT` - `number` **required** + +The port number to connect to the mail server on. + +`SMTP_USER` - `string` + +If the mail server requires authentication, the username to use. + +`SMTP_PASS` - `string` + +If the mail server requires authentication, the password to use. + +`SMTP_MAX_FREQUENCY` - `number` + +Controls the minimum amount of time that must pass before sending another signup confirmation or password reset email. The value is the number of seconds. Defaults to 900 (15 minutes). + +`SMTP_SENDER_NAME` - `string` + +Sets the name of the sender. Defaults to the `SMTP_ADMIN_EMAIL` if not used. + +`MAILER_AUTOCONFIRM` - `bool` + +If you do not require email confirmation, you may set this to `true`. Defaults to `false`. + +`MAILER_OTP_EXP` - `number` + +Controls the duration an email link or otp is valid for. + +`MAILER_URLPATHS_INVITE` - `string` + +URL path to use in the user invite email. Defaults to `/verify`. + +`MAILER_URLPATHS_CONFIRMATION` - `string` + +URL path to use in the signup confirmation email. Defaults to `/verify`. 
+ +`MAILER_URLPATHS_RECOVERY` - `string` + +URL path to use in the password reset email. Defaults to `/verify`. + +`MAILER_URLPATHS_EMAIL_CHANGE` - `string` + +URL path to use in the email change confirmation email. Defaults to `/verify`. + +`MAILER_SUBJECTS_INVITE` - `string` + +Email subject to use for user invite. Defaults to `You have been invited`. + +`MAILER_SUBJECTS_CONFIRMATION` - `string` + +Email subject to use for signup confirmation. Defaults to `Confirm Your Signup`. + +`MAILER_SUBJECTS_RECOVERY` - `string` + +Email subject to use for password reset. Defaults to `Reset Your Password`. + +`MAILER_SUBJECTS_MAGIC_LINK` - `string` + +Email subject to use for magic link email. Defaults to `Your Magic Link`. + +`MAILER_SUBJECTS_EMAIL_CHANGE` - `string` + +Email subject to use for email change confirmation. Defaults to `Confirm Email Change`. + +`MAILER_TEMPLATES_INVITE` - `string` + +URL path to an email template to use when inviting a user. (e.g. `https://www.example.com/path-to-email-template.html`) +`SiteURL`, `Email`, and `ConfirmationURL` variables are available. + +Default Content (if template is unavailable): + +```html +

+<h2>You have been invited</h2>
+
+<p>
+  You have been invited to create a user on {{ .SiteURL }}. Follow this link to
+  accept the invite:
+</p>
+<p><a href="{{ .ConfirmationURL }}">Accept the invite</a></p>

+``` + +`MAILER_TEMPLATES_CONFIRMATION` - `string` + +URL path to an email template to use when confirming a signup. (e.g. `https://www.example.com/path-to-email-template.html`) +`SiteURL`, `Email`, and `ConfirmationURL` variables are available. + +Default Content (if template is unavailable): + +```html +

+<h2>Confirm your signup</h2>
+
+<p>Follow this link to confirm your user:</p>
+<p><a href="{{ .ConfirmationURL }}">Confirm your mail</a></p>

+``` + +`MAILER_TEMPLATES_RECOVERY` - `string` + +URL path to an email template to use when resetting a password. (e.g. `https://www.example.com/path-to-email-template.html`) +`SiteURL`, `Email`, and `ConfirmationURL` variables are available. + +Default Content (if template is unavailable): + +```html +

+<h2>Reset Password</h2>
+
+<p>Follow this link to reset the password for your user:</p>
+<p><a href="{{ .ConfirmationURL }}">Reset Password</a></p>

+``` + +`MAILER_TEMPLATES_MAGIC_LINK` - `string` + +URL path to an email template to use when sending magic link. (e.g. `https://www.example.com/path-to-email-template.html`) +`SiteURL`, `Email`, and `ConfirmationURL` variables are available. + +Default Content (if template is unavailable): + +```html +

+<h2>Magic Link</h2>
+
+<p>Follow this link to login:</p>
+<p><a href="{{ .ConfirmationURL }}">Log In</a></p>

+``` + +`MAILER_TEMPLATES_EMAIL_CHANGE` - `string` + +URL path to an email template to use when confirming the change of an email address. (e.g. `https://www.example.com/path-to-email-template.html`) +`SiteURL`, `Email`, `NewEmail`, and `ConfirmationURL` variables are available. + +Default Content (if template is unavailable): + +```html +

+<h2>Confirm Change of Email</h2>
+
+<p>
+  Follow this link to confirm the update of your email from {{ .Email }} to {{
+  .NewEmail }}:
+</p>
+<p><a href="{{ .ConfirmationURL }}">Change Email</a></p>

+``` + +### Phone Auth + +`SMS_AUTOCONFIRM` - `bool` + +If you do not require phone confirmation, you may set this to `true`. Defaults to `false`. + +`SMS_MAX_FREQUENCY` - `number` + +Controls the minimum amount of time that must pass before sending another sms otp. The value is the number of seconds. Defaults to 60 (1 minute)). + +`SMS_OTP_EXP` - `number` + +Controls the duration an sms otp is valid for. + +`SMS_OTP_LENGTH` - `number` + +Controls the number of digits of the sms otp sent. + +`SMS_PROVIDER` - `string` + +Available options are: `twilio`, `messagebird`, `textlocal`, and `vonage` + +Then you can use your [twilio credentials](https://www.twilio.com/docs/usage/requests-to-twilio#credentials): + +- `SMS_TWILIO_ACCOUNT_SID` +- `SMS_TWILIO_AUTH_TOKEN` +- `SMS_TWILIO_MESSAGE_SERVICE_SID` - can be set to your twilio sender mobile number + +Or Messagebird credentials, which can be obtained in the [Dashboard](https://dashboard.messagebird.com/en/developers/access): + +- `SMS_MESSAGEBIRD_ACCESS_KEY` - your Messagebird access key +- `SMS_MESSAGEBIRD_ORIGINATOR` - SMS sender (your Messagebird phone number with + or company name) + +### CAPTCHA + +- If enabled, CAPTCHA will check the request body for the `captcha_token` field and make a verification request to the CAPTCHA provider. + +`SECURITY_CAPTCHA_ENABLED` - `string` + +Whether captcha middleware is enabled + +`SECURITY_CAPTCHA_PROVIDER` - `string` + +for now the only options supported are: `hcaptcha` and `turnstile` + +- `SECURITY_CAPTCHA_SECRET` - `string` +- `SECURITY_CAPTCHA_TIMEOUT` - `string` + +Retrieve from hcaptcha or turnstile account + +### Reauthentication + +`SECURITY_UPDATE_PASSWORD_REQUIRE_REAUTHENTICATION` - `bool` + +Enforce reauthentication on password update. + +### Anonymous Sign-Ins + +`GOTRUE_EXTERNAL_ANONYMOUS_USERS_ENABLED` - `bool` + +Use this to enable/disable anonymous sign-ins. + +## Endpoints + +Auth exposes the following endpoints: + +### **GET /settings** + +Returns the publicly available settings for this auth instance. + +```json +{ + "external": { + "apple": true, + "azure": true, + "bitbucket": true, + "discord": true, + "facebook": true, + "figma": true, + "github": true, + "gitlab": true, + "google": true, + "keycloak": true, + "linkedin": true, + "notion": true, + "slack": true, + "spotify": true, + "twitch": true, + "twitter": true, + "workos": true + }, + "disable_signup": false, + "autoconfirm": false +} +``` + +### **POST, PUT /admin/users/** + +Creates (POST) or Updates (PUT) the user based on the `user_id` specified. The `ban_duration` field accepts the following time units: "ns", "us", "ms", "s", "m", "h". See [`time.ParseDuration`](https://pkg.go.dev/time#ParseDuration) for more details on the format used. + +```js +headers: +{ + "Authorization": "Bearer eyJhbGciOiJI...M3A90LCkxxtX9oNP9KZO" // requires a role claim that can be set in the GOTRUE_JWT_ADMIN_ROLES env var +} + +body: +{ + "role": "test-user", + "email": "email@example.com", + "phone": "12345678", + "password": "secret", // only if type = signup + "email_confirm": true, + "phone_confirm": true, + "user_metadata": {}, + "app_metadata": {}, + "ban_duration": "24h" or "none" // to unban a user +} +``` + +### **POST /admin/generate_link** + +Returns the corresponding email action link based on the type specified. Among other things, the response also contains the query params of the action link as separate JSON fields for convenience (along with the email OTP from which the corresponding token is generated). 
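+
+This commit also vendors a generated Go admin client for these endpoints
+(`auth_v2.169.0/client/admin/client.go`, further down in this diff), which exposes
+this route as `PostGenerateLink`. The snippet below is a minimal, hypothetical
+sketch of calling the endpoint through that client; the import path, server URL,
+and token value are assumptions, not verified configuration. The raw HTTP request
+and response are shown below.
+
+```go
+// Hypothetical sketch only: generate a signup action link through the vendored
+// admin client. Import path, server URL, and JWT value are assumptions.
+package main
+
+import (
+    "context"
+    "fmt"
+    "log"
+    "net/http"
+
+    openapi_types "github.com/deepmap/oapi-codegen/pkg/types"
+    "github.com/supabase/auth/client/admin" // assumed module path for auth_v2.169.0/client/admin
+)
+
+func main() {
+    adminJWT := "eyJhbGciOiJI..." // a JWT whose role is listed in GOTRUE_JWT_ADMIN_ROLES
+
+    // Attach the admin bearer token to every request via a request editor.
+    client, err := admin.NewClient("http://localhost:9999",
+        admin.WithRequestEditorFn(func(ctx context.Context, req *http.Request) error {
+            req.Header.Set("Authorization", "Bearer "+adminJWT)
+            return nil
+        }))
+    if err != nil {
+        log.Fatal(err)
+    }
+
+    redirectTo := "https://supabase.io"
+    resp, err := client.PostGenerateLink(context.Background(), admin.PostGenerateLinkJSONRequestBody{
+        Type:       admin.Signup,
+        Email:      openapi_types.Email("email@example.com"),
+        RedirectTo: &redirectTo,
+    })
+    if err != nil {
+        log.Fatal(err)
+    }
+    defer resp.Body.Close()
+    fmt.Println("status:", resp.StatusCode) // action_link, email_otp, etc. are in the JSON body
+}
+```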
+ +```js +headers: +{ + "Authorization": "Bearer eyJhbGciOiJI...M3A90LCkxxtX9oNP9KZO" // admin role required +} + +body: +{ + "type": "signup" or "magiclink" or "recovery" or "invite", + "email": "email@example.com", + "password": "secret", // only if type = signup + "data": { + ... + }, // only if type = signup + "redirect_to": "https://supabase.io" // Redirect URL to send the user to after an email action. Defaults to SITE_URL. + +} +``` + +Returns + +```js +{ + "action_link": "http://localhost:9999/verify?token=TOKEN&type=TYPE&redirect_to=REDIRECT_URL", + "email_otp": "EMAIL_OTP", + "hashed_token": "TOKEN", + "verification_type": "TYPE", + "redirect_to": "REDIRECT_URL", + ... +} +``` + +### **POST /signup** + +Register a new user with an email and password. + +```json +{ + "email": "email@example.com", + "password": "secret" +} +``` + +returns: + +```json +{ + "id": "11111111-2222-3333-4444-5555555555555", + "email": "email@example.com", + "confirmation_sent_at": "2016-05-15T20:49:40.882805774-07:00", + "created_at": "2016-05-15T19:53:12.368652374-07:00", + "updated_at": "2016-05-15T19:53:12.368652374-07:00" +} + +// if sign up is a duplicate then faux data will be returned +// as to not leak information about whether a given email +// has an account with your service or not +``` + +Register a new user with a phone number and password. + +```js +{ + "phone": "12345678", // follows the E.164 format + "password": "secret" +} +``` + +Returns: + +```json +{ + "id": "11111111-2222-3333-4444-5555555555555", // if duplicate sign up, this ID will be faux + "phone": "12345678", + "confirmation_sent_at": "2016-05-15T20:49:40.882805774-07:00", + "created_at": "2016-05-15T19:53:12.368652374-07:00", + "updated_at": "2016-05-15T19:53:12.368652374-07:00" +} +``` + +if AUTOCONFIRM is enabled and the sign up is a duplicate, then the endpoint will return: + +```json +{ + "code":400, + "msg":"User already registered" +} +``` + +### **POST /resend** + +Allows a user to resend an existing signup, sms, email_change or phone_change OTP. + +```json +{ + "email": "user@example.com", + "type": "signup" +} +``` + +```json +{ + "phone": "12345678", + "type": "sms" +} +``` + +returns: + +```json +{ + "message_id": "msgid123456" +} +``` + +### **POST /invite** + +Invites a new user with an email. +This endpoint requires the `service_role` or `supabase_admin` JWT set as an Auth Bearer header: + +e.g. + +```json +headers: { + "Authorization" : "Bearer eyJhbGciOiJI...M3A90LCkxxtX9oNP9KZO" +} +``` + +```json +{ + "email": "email@example.com" +} +``` + +Returns: + +```json +{ + "id": "11111111-2222-3333-4444-5555555555555", + "email": "email@example.com", + "confirmation_sent_at": "2016-05-15T20:49:40.882805774-07:00", + "created_at": "2016-05-15T19:53:12.368652374-07:00", + "updated_at": "2016-05-15T19:53:12.368652374-07:00", + "invited_at": "2016-05-15T19:53:12.368652374-07:00" +} +``` + +### **POST /verify** + +Verify a registration or a password recovery. Type can be `signup` or `recovery` or `invite` +and the `token` is a token returned from either `/signup` or `/recover`. + +```json +{ + "type": "signup", + "token": "confirmation-code-delivered-in-email" +} +``` + +`password` is required for signup verification if no existing password exists. + +Returns: + +```json +{ + "access_token": "jwt-token-representing-the-user", + "token_type": "bearer", + "expires_in": 3600, + "refresh_token": "a-refresh-token", + "type": "signup | recovery | invite" +} +``` + +Verify a phone signup or sms otp. Type should be set to `sms`. 
+ +```json +{ + "type": "sms", + "token": "confirmation-otp-delivered-in-sms", + "redirect_to": "https://supabase.io", + "phone": "phone-number-sms-otp-was-delivered-to" +} +``` + +Returns: + +```json +{ + "access_token": "jwt-token-representing-the-user", + "token_type": "bearer", + "expires_in": 3600, + "refresh_token": "a-refresh-token" +} +``` + +### **GET /verify** + +Verify a registration or a password recovery. Type can be `signup` or `recovery` or `magiclink` or `invite` +and the `token` is a token returned from either `/signup` or `/recover` or `/magiclink`. + +query params: + +```json +{ + "type": "signup", + "token": "confirmation-code-delivered-in-email", + "redirect_to": "https://supabase.io" +} +``` + +User will be logged in and redirected to: + +```json +SITE_URL/#access_token=jwt-token-representing-the-user&token_type=bearer&expires_in=3600&refresh_token=a-refresh-token&type=invite +``` + +Your app should detect the query params in the fragment and use them to set the session (supabase-js does this automatically) + +You can use the `type` param to redirect the user to a password set form in the case of `invite` or `recovery`, +or show an account confirmed/welcome message in the case of `signup`, or direct them to some additional onboarding flow + +### **POST /otp** + +One-Time-Password. Will deliver a magiclink or sms otp to the user depending on whether the request body contains an "email" or "phone" key. + +If `"create_user": true`, user will not be automatically signed up if the user doesn't exist. + +```json +{ + "phone": "12345678" // follows the E.164 format + "create_user": true +} +``` + +OR + +```json +// exactly the same as /magiclink +{ + "email": "email@example.com" + "create_user": true +} +``` + +Returns: + +```json +{} +``` + +### **POST /magiclink** (recommended to use /otp instead. See above.) + +Magic Link. Will deliver a link (e.g. `/verify?type=magiclink&token=fgtyuf68ddqdaDd`) to the user based on +email address which they can use to redeem an access_token. + +By default Magic Links can only be sent once every 60 seconds + +```json +{ + "email": "email@example.com" +} +``` + +Returns: + +```json +{} +``` + +when clicked the magic link will redirect the user to `#access_token=x&refresh_token=y&expires_in=z&token_type=bearer&type=magiclink` (see `/verify` above) + +### **POST /recover** + +Password recovery. Will deliver a password recovery mail to the user based on +email address. + +By default recovery links can only be sent once every 60 seconds + +```json +{ + "email": "email@example.com" +} +``` + +Returns: + +```json +{} +``` + +### **POST /token** + +This is an OAuth2 endpoint that currently implements +the password and refresh_token grant types + +query params: + +``` +?grant_type=password +``` + +body: + +```json +// Email login +{ + "email": "name@domain.com", + "password": "somepassword" +} + +// Phone login +{ + "phone": "12345678", + "password": "somepassword" +} +``` + +or + +query params: + +``` +grant_type=refresh_token +``` + +body: + +```json +{ + "refresh_token": "a-refresh-token" +} +``` + +Once you have an access token, you can access the methods requiring authentication +by settings the `Authorization: Bearer YOUR_ACCESS_TOKEN_HERE` header. 
+ +Returns: + +```json +{ + "access_token": "jwt-token-representing-the-user", + "token_type": "bearer", + "expires_in": 3600, + "refresh_token": "a-refresh-token" +} +``` + +### **GET /user** + +Get the JSON object for the logged in user (requires authentication) + +Returns: + +```json +{ + "id": "11111111-2222-3333-4444-5555555555555", + "email": "email@example.com", + "confirmation_sent_at": "2016-05-15T20:49:40.882805774-07:00", + "created_at": "2016-05-15T19:53:12.368652374-07:00", + "updated_at": "2016-05-15T19:53:12.368652374-07:00" +} +``` + +### **PUT /user** + +Update a user (Requires authentication). Apart from changing email/password, this +method can be used to set custom user data. Changing the email will result in a magiclink being sent out. + +```json +{ + "email": "new-email@example.com", + "password": "new-password", + "phone": "+123456789", + "data": { + "key": "value", + "number": 10, + "admin": false + } +} +``` + +Returns: + +```json +{ + "id": "11111111-2222-3333-4444-5555555555555", + "email": "email@example.com", + "email_change_sent_at": "2016-05-15T20:49:40.882805774-07:00", + "phone": "+123456789", + "phone_change_sent_at": "2016-05-15T20:49:40.882805774-07:00", + "created_at": "2016-05-15T19:53:12.368652374-07:00", + "updated_at": "2016-05-15T19:53:12.368652374-07:00" +} +``` + +If `GOTRUE_SECURITY_UPDATE_PASSWORD_REQUIRE_REAUTHENTICATION` is enabled, the user will need to reauthenticate first. + +```json +{ + "password": "new-password", + "nonce": "123456" +} +``` + +### **GET /reauthenticate** + +Sends a nonce to the user's email (preferred) or phone. This endpoint requires the user to be logged in / authenticated first. The user needs to have either an email or phone number for the nonce to be sent successfully. + +```json +headers: { + "Authorization" : "Bearer eyJhbGciOiJI...M3A90LCkxxtX9oNP9KZO" +} +``` + +### **POST /logout** + +Logout a user (Requires authentication). + +This will revoke all refresh tokens for the user. Remember that the JWT tokens +will still be valid for stateless auth until they expires. + +### **GET /authorize** + +Get access_token from external oauth provider + +query params: + +``` +provider=apple | azure | bitbucket | discord | facebook | figma | github | gitlab | google | keycloak | linkedin | notion | slack | spotify | twitch | twitter | workos + +scopes= +``` + +Redirects to provider and then to `/callback` + +For apple specific setup see: + +### **GET /callback** + +External provider should redirect to here + +Redirects to `#access_token=&refresh_token=&provider_token=&expires_in=3600&provider=` +If additional scopes were requested then `provider_token` will be populated, you can use this to fetch additional data from the provider or interact with their services diff --git a/auth_v2.169.0/SECURITY.md b/auth_v2.169.0/SECURITY.md new file mode 100644 index 0000000..c607303 --- /dev/null +++ b/auth_v2.169.0/SECURITY.md @@ -0,0 +1,60 @@ +# Security Policy + +Auth is a project maintained by [Supabase](https://supabase.com). Below is +our security policy. + +Contact: security@supabase.io +Canonical: https://supabase.com/.well-known/security.txt + +At Supabase, we consider the security of our systems a top priority. But no +matter how much effort we put into system security, there can still be +vulnerabilities present. + +If you discover a vulnerability, we would like to know about it so we can take +steps to address it as quickly as possible. We would like to ask you to help us +better protect our clients and our systems. 
+ +Out of scope vulnerabilities: + +- Clickjacking on pages with no sensitive actions. +- Unauthenticated/logout/login CSRF. +- Attacks requiring MITM or physical access to a user's device. +- Any activity that could lead to the disruption of our service (DoS). +- Content spoofing and text injection issues without showing an attack + vector/without being able to modify HTML/CSS. +- Email spoofing +- Missing DNSSEC, CAA, CSP headers +- Lack of Secure or HTTP only flag on non-sensitive cookies +- Deadlinks + +Please do the following: + +- E-mail your findings to security@supabase.io. +- Do not run automated scanners on our infrastructure or dashboard. If you wish + to do this, contact us and we will set up a sandbox for you. +- Do not take advantage of the vulnerability or problem you have discovered, + for example by downloading more data than necessary to demonstrate the + vulnerability or deleting or modifying other people's data, +- Do not reveal the problem to others until it has been resolved, +- Do not use attacks on physical security, social engineering, distributed + denial of service, spam or applications of third parties, and +- Do provide sufficient information to reproduce the problem, so we will be + able to resolve it as quickly as possible. Usually, the IP address or the URL + of the affected system and a description of the vulnerability will be + sufficient, but complex vulnerabilities may require further explanation. + +What we promise: + +- We will respond to your report within 3 business days with our evaluation of + the report and an expected resolution date, +- If you have followed the instructions above, we will not take any legal + action against you in regard to the report, +- We will handle your report with strict confidentiality, and not pass on your + personal details to third parties without your permission, +- We will keep you informed of the progress towards resolving the problem, +- In the public information concerning the problem reported, we will give your + name as the discoverer of the problem (unless you desire otherwise), and + +We strive to resolve all problems as quickly as possible, and we would like to +play an active role in the ultimate publication on the problem after it is +resolved. 
diff --git a/auth_v2.169.0/app.json b/auth_v2.169.0/app.json new file mode 100644 index 0000000..4868656 --- /dev/null +++ b/auth_v2.169.0/app.json @@ -0,0 +1,34 @@ +{ + "name": "Gotrue", + "description": "", + "website": "https://www.gotrueapi.org", + "repository": "https://github.com/supabase/gotrue", + "env": { + "DATABASE_URL": {}, + "GOTRUE_DB_DRIVER": { + "value": "postgres" + }, + "GOTRUE_DB_AUTOMIGRATE": { + "value": true + }, + "GOTRUE_DB_NAMESPACE": { + "value": "auth" + }, + "GOTRUE_JWT_SECRET": { + "required": true + }, + "GOTRUE_SMTP_ADMIN_EMAIL": {}, + "GOTRUE_SMTP_HOST": {}, + "GOTRUE_SMTP_PASS": {}, + "GOTRUE_SMTP_PORT": {}, + "GOTRUE_MAILER_SITE_URL": {}, + "GOTRUE_MAILER_SUBJECTS_CONFIRMATION": {}, + "GOTRUE_MAILER_SUBJECTS_RECOVERY": {}, + "GOTRUE_MAILER_SUBJECTS_MAGIC_LINK": {}, + "GOTRUE_MAILER_TEMPLATES_CONFIRMATION": {}, + "GOTRUE_MAILER_TEMPLATES_EMAIL_CHANGE": {}, + "GOTRUE_MAILER_TEMPLATES_RECOVERY": {}, + "GOTRUE_MAILER_TEMPLATES_MAGIC_LINK": {}, + "GOTRUE_MAILER_USER": {} + } +} diff --git a/auth_v2.169.0/client/admin/client.go b/auth_v2.169.0/client/admin/client.go new file mode 100644 index 0000000..7abe6e8 --- /dev/null +++ b/auth_v2.169.0/client/admin/client.go @@ -0,0 +1,2674 @@ +// Package admin provides primitives to interact with the openapi HTTP API. +// +// Code generated by github.com/deepmap/oapi-codegen version v1.12.4 DO NOT EDIT. +package admin + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/deepmap/oapi-codegen/pkg/runtime" + openapi_types "github.com/deepmap/oapi-codegen/pkg/types" +) + +const ( + APIKeyAuthScopes = "APIKeyAuth.Scopes" + AdminAuthScopes = "AdminAuth.Scopes" +) + +// Defines values for PostAdminSsoProvidersJSONBodyType. +const ( + Saml PostAdminSsoProvidersJSONBodyType = "saml" +) + +// Defines values for PostGenerateLinkJSONBodyType. +const ( + EmailChangeCurrent PostGenerateLinkJSONBodyType = "email_change_current" + EmailChangeNew PostGenerateLinkJSONBodyType = "email_change_new" + Magiclink PostGenerateLinkJSONBodyType = "magiclink" + Recovery PostGenerateLinkJSONBodyType = "recovery" + Signup PostGenerateLinkJSONBodyType = "signup" +) + +// ErrorSchema defines model for ErrorSchema. +type ErrorSchema struct { + // Code The HTTP status code. Usually missing if `error` is present. + Code *int `json:"code,omitempty"` + + // Error Certain responses will contain this property with the provided values. + // + // Usually one of these: + // - invalid_request + // - unauthorized_client + // - access_denied + // - server_error + // - temporarily_unavailable + // - unsupported_otp_type + Error *string `json:"error,omitempty"` + + // ErrorDescription Certain responses that have an `error` property may have this property which describes the error. + ErrorDescription *string `json:"error_description,omitempty"` + + // Msg A basic message describing the problem with the request. Usually missing if `error` is present. + Msg *string `json:"msg,omitempty"` +} + +// MFAFactorSchema Represents a MFA factor. +type MFAFactorSchema struct { + // FactorType Usually one of: + // - totp + FactorType *string `json:"factor_type,omitempty"` + FriendlyName *string `json:"friendly_name,omitempty"` + Id *openapi_types.UUID `json:"id,omitempty"` + + // Status Usually one of: + // - verified + // - unverified + Status *string `json:"status,omitempty"` +} + +// SAMLAttributeMappingSchema defines model for SAMLAttributeMappingSchema. 
+type SAMLAttributeMappingSchema struct { + Keys *map[string]interface{} `json:"keys,omitempty"` +} + +// SSOProviderSchema defines model for SSOProviderSchema. +type SSOProviderSchema struct { + Id *openapi_types.UUID `json:"id,omitempty"` + Saml *struct { + AttributeMapping *SAMLAttributeMappingSchema `json:"attribute_mapping,omitempty"` + EntityId *string `json:"entity_id,omitempty"` + MetadataUrl *string `json:"metadata_url,omitempty"` + MetadataXml *string `json:"metadata_xml,omitempty"` + } `json:"saml,omitempty"` + SsoDomains *[]struct { + Domain *string `json:"domain,omitempty"` + } `json:"sso_domains,omitempty"` +} + +// UserSchema Object describing the user related to the issued access and refresh tokens. +type UserSchema struct { + AppMetadata *map[string]interface{} `json:"app_metadata,omitempty"` + Aud *string `json:"aud,omitempty"` + BannedUntil *time.Time `json:"banned_until,omitempty"` + ConfirmationSentAt *time.Time `json:"confirmation_sent_at,omitempty"` + ConfirmedAt *time.Time `json:"confirmed_at,omitempty"` + CreatedAt *time.Time `json:"created_at,omitempty"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` + + // Email User's primary contact email. In most cases you can uniquely identify a user by their email address, but not in all cases. + Email *string `json:"email,omitempty"` + EmailChangeSentAt *time.Time `json:"email_change_sent_at,omitempty"` + EmailConfirmedAt *time.Time `json:"email_confirmed_at,omitempty"` + Factors *[]MFAFactorSchema `json:"factors,omitempty"` + Id *openapi_types.UUID `json:"id,omitempty"` + Identities *[]map[string]interface{} `json:"identities,omitempty"` + LastSignInAt *time.Time `json:"last_sign_in_at,omitempty"` + NewEmail *openapi_types.Email `json:"new_email,omitempty"` + NewPhone *string `json:"new_phone,omitempty"` + + // Phone User's primary contact phone number. In most cases you can uniquely identify a user by their phone number, but not in all cases. + Phone *string `json:"phone,omitempty"` + PhoneChangeSentAt *time.Time `json:"phone_change_sent_at,omitempty"` + PhoneConfirmedAt *time.Time `json:"phone_confirmed_at,omitempty"` + ReauthenticationSentAt *time.Time `json:"reauthentication_sent_at,omitempty"` + RecoverySentAt *time.Time `json:"recovery_sent_at,omitempty"` + Role *string `json:"role,omitempty"` + UpdatedAt *time.Time `json:"updated_at,omitempty"` + UserMetadata *map[string]interface{} `json:"user_metadata,omitempty"` +} + +// BadRequestResponse defines model for BadRequestResponse. +type BadRequestResponse = ErrorSchema + +// ForbiddenResponse defines model for ForbiddenResponse. +type ForbiddenResponse = ErrorSchema + +// UnauthorizedResponse defines model for UnauthorizedResponse. +type UnauthorizedResponse = ErrorSchema + +// GetAdminAuditParams defines parameters for GetAdminAudit. +type GetAdminAuditParams struct { + Page *int `form:"page,omitempty" json:"page,omitempty"` + PerPage *int `form:"per_page,omitempty" json:"per_page,omitempty"` +} + +// PostAdminSsoProvidersJSONBody defines parameters for PostAdminSsoProviders. +type PostAdminSsoProvidersJSONBody struct { + AttributeMapping *SAMLAttributeMappingSchema `json:"attribute_mapping,omitempty"` + Domains *[]string `json:"domains,omitempty"` + MetadataUrl *string `json:"metadata_url,omitempty"` + MetadataXml *string `json:"metadata_xml,omitempty"` + Type PostAdminSsoProvidersJSONBodyType `json:"type"` +} + +// PostAdminSsoProvidersJSONBodyType defines parameters for PostAdminSsoProviders. 
+type PostAdminSsoProvidersJSONBodyType string + +// PutAdminSsoProvidersSsoProviderIdJSONBody defines parameters for PutAdminSsoProvidersSsoProviderId. +type PutAdminSsoProvidersSsoProviderIdJSONBody struct { + AttributeMapping *SAMLAttributeMappingSchema `json:"attribute_mapping,omitempty"` + Domains *[]string `json:"domains,omitempty"` + MetadataUrl *string `json:"metadata_url,omitempty"` + MetadataXml *string `json:"metadata_xml,omitempty"` +} + +// GetAdminUsersParams defines parameters for GetAdminUsers. +type GetAdminUsersParams struct { + Page *int `form:"page,omitempty" json:"page,omitempty"` + PerPage *int `form:"per_page,omitempty" json:"per_page,omitempty"` +} + +// PutAdminUsersUserIdFactorsFactorIdJSONBody defines parameters for PutAdminUsersUserIdFactorsFactorId. +type PutAdminUsersUserIdFactorsFactorIdJSONBody = map[string]interface{} + +// PostGenerateLinkJSONBody defines parameters for PostGenerateLink. +type PostGenerateLinkJSONBody struct { + Data *map[string]interface{} `json:"data,omitempty"` + Email openapi_types.Email `json:"email"` + NewEmail *openapi_types.Email `json:"new_email,omitempty"` + Password *string `json:"password,omitempty"` + RedirectTo *string `json:"redirect_to,omitempty"` + Type PostGenerateLinkJSONBodyType `json:"type"` +} + +// PostGenerateLinkJSONBodyType defines parameters for PostGenerateLink. +type PostGenerateLinkJSONBodyType string + +// PostInviteJSONBody defines parameters for PostInvite. +type PostInviteJSONBody struct { + Data *map[string]interface{} `json:"data,omitempty"` + Email string `json:"email"` +} + +// PostAdminSsoProvidersJSONRequestBody defines body for PostAdminSsoProviders for application/json ContentType. +type PostAdminSsoProvidersJSONRequestBody PostAdminSsoProvidersJSONBody + +// PutAdminSsoProvidersSsoProviderIdJSONRequestBody defines body for PutAdminSsoProvidersSsoProviderId for application/json ContentType. +type PutAdminSsoProvidersSsoProviderIdJSONRequestBody PutAdminSsoProvidersSsoProviderIdJSONBody + +// PutAdminUsersUserIdJSONRequestBody defines body for PutAdminUsersUserId for application/json ContentType. +type PutAdminUsersUserIdJSONRequestBody = UserSchema + +// PutAdminUsersUserIdFactorsFactorIdJSONRequestBody defines body for PutAdminUsersUserIdFactorsFactorId for application/json ContentType. +type PutAdminUsersUserIdFactorsFactorIdJSONRequestBody = PutAdminUsersUserIdFactorsFactorIdJSONBody + +// PostGenerateLinkJSONRequestBody defines body for PostGenerateLink for application/json ContentType. +type PostGenerateLinkJSONRequestBody PostGenerateLinkJSONBody + +// PostInviteJSONRequestBody defines body for PostInvite for application/json ContentType. +type PostInviteJSONRequestBody PostInviteJSONBody + +// RequestEditorFn is the function signature for the RequestEditor callback function +type RequestEditorFn func(ctx context.Context, req *http.Request) error + +// Doer performs HTTP requests. +// +// The standard http.Client implements this interface. +type HttpRequestDoer interface { + Do(req *http.Request) (*http.Response, error) +} + +// Client which conforms to the OpenAPI3 specification for this service. +type Client struct { + // The endpoint of the server conforming to this interface, with scheme, + // https://api.deepmap.com for example. This can contain a path relative + // to the server, such as https://api.deepmap.com/dev-test, and all the + // paths in the swagger spec will be appended to the server. 
+ Server string + + // Doer for performing requests, typically a *http.Client with any + // customized settings, such as certificate chains. + Client HttpRequestDoer + + // A list of callbacks for modifying requests which are generated before sending over + // the network. + RequestEditors []RequestEditorFn +} + +// ClientOption allows setting custom parameters during construction +type ClientOption func(*Client) error + +// Creates a new Client, with reasonable defaults +func NewClient(server string, opts ...ClientOption) (*Client, error) { + // create a client with sane default values + client := Client{ + Server: server, + } + // mutate client and add all optional params + for _, o := range opts { + if err := o(&client); err != nil { + return nil, err + } + } + // ensure the server URL always has a trailing slash + if !strings.HasSuffix(client.Server, "/") { + client.Server += "/" + } + // create httpClient, if not already present + if client.Client == nil { + client.Client = &http.Client{} + } + return &client, nil +} + +// WithHTTPClient allows overriding the default Doer, which is +// automatically created using http.Client. This is useful for tests. +func WithHTTPClient(doer HttpRequestDoer) ClientOption { + return func(c *Client) error { + c.Client = doer + return nil + } +} + +// WithRequestEditorFn allows setting up a callback function, which will be +// called right before sending the request. This can be used to mutate the request. +func WithRequestEditorFn(fn RequestEditorFn) ClientOption { + return func(c *Client) error { + c.RequestEditors = append(c.RequestEditors, fn) + return nil + } +} + +// The interface specification for the client above. +type ClientInterface interface { + // GetAdminAudit request + GetAdminAudit(ctx context.Context, params *GetAdminAuditParams, reqEditors ...RequestEditorFn) (*http.Response, error) + + // GetAdminSsoProviders request + GetAdminSsoProviders(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) + + // PostAdminSsoProviders request with any body + PostAdminSsoProvidersWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + PostAdminSsoProviders(ctx context.Context, body PostAdminSsoProvidersJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + + // DeleteAdminSsoProvidersSsoProviderId request + DeleteAdminSsoProvidersSsoProviderId(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) + + // GetAdminSsoProvidersSsoProviderId request + GetAdminSsoProvidersSsoProviderId(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) + + // PutAdminSsoProvidersSsoProviderId request with any body + PutAdminSsoProvidersSsoProviderIdWithBody(ctx context.Context, ssoProviderId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + PutAdminSsoProvidersSsoProviderId(ctx context.Context, ssoProviderId openapi_types.UUID, body PutAdminSsoProvidersSsoProviderIdJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + + // GetAdminUsers request + GetAdminUsers(ctx context.Context, params *GetAdminUsersParams, reqEditors ...RequestEditorFn) (*http.Response, error) + + // DeleteAdminUsersUserId request + DeleteAdminUsersUserId(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) + + // GetAdminUsersUserId request + 
GetAdminUsersUserId(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) + + // PutAdminUsersUserId request with any body + PutAdminUsersUserIdWithBody(ctx context.Context, userId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + PutAdminUsersUserId(ctx context.Context, userId openapi_types.UUID, body PutAdminUsersUserIdJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + + // GetAdminUsersUserIdFactors request + GetAdminUsersUserIdFactors(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) + + // DeleteAdminUsersUserIdFactorsFactorId request + DeleteAdminUsersUserIdFactorsFactorId(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) + + // PutAdminUsersUserIdFactorsFactorId request with any body + PutAdminUsersUserIdFactorsFactorIdWithBody(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + PutAdminUsersUserIdFactorsFactorId(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, body PutAdminUsersUserIdFactorsFactorIdJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + + // PostGenerateLink request with any body + PostGenerateLinkWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + PostGenerateLink(ctx context.Context, body PostGenerateLinkJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + + // PostInvite request with any body + PostInviteWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + PostInvite(ctx context.Context, body PostInviteJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) +} + +func (c *Client) GetAdminAudit(ctx context.Context, params *GetAdminAuditParams, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGetAdminAuditRequest(c.Server, params) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) GetAdminSsoProviders(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGetAdminSsoProvidersRequest(c.Server) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PostAdminSsoProvidersWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPostAdminSsoProvidersRequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PostAdminSsoProviders(ctx context.Context, body PostAdminSsoProvidersJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPostAdminSsoProvidersRequest(c.Server, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { 
+ return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) DeleteAdminSsoProvidersSsoProviderId(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewDeleteAdminSsoProvidersSsoProviderIdRequest(c.Server, ssoProviderId) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) GetAdminSsoProvidersSsoProviderId(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGetAdminSsoProvidersSsoProviderIdRequest(c.Server, ssoProviderId) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PutAdminSsoProvidersSsoProviderIdWithBody(ctx context.Context, ssoProviderId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPutAdminSsoProvidersSsoProviderIdRequestWithBody(c.Server, ssoProviderId, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PutAdminSsoProvidersSsoProviderId(ctx context.Context, ssoProviderId openapi_types.UUID, body PutAdminSsoProvidersSsoProviderIdJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPutAdminSsoProvidersSsoProviderIdRequest(c.Server, ssoProviderId, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) GetAdminUsers(ctx context.Context, params *GetAdminUsersParams, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGetAdminUsersRequest(c.Server, params) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) DeleteAdminUsersUserId(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewDeleteAdminUsersUserIdRequest(c.Server, userId) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) GetAdminUsersUserId(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGetAdminUsersUserIdRequest(c.Server, userId) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PutAdminUsersUserIdWithBody(ctx context.Context, userId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPutAdminUsersUserIdRequestWithBody(c.Server, userId, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c 
*Client) PutAdminUsersUserId(ctx context.Context, userId openapi_types.UUID, body PutAdminUsersUserIdJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPutAdminUsersUserIdRequest(c.Server, userId, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) GetAdminUsersUserIdFactors(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGetAdminUsersUserIdFactorsRequest(c.Server, userId) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) DeleteAdminUsersUserIdFactorsFactorId(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewDeleteAdminUsersUserIdFactorsFactorIdRequest(c.Server, userId, factorId) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PutAdminUsersUserIdFactorsFactorIdWithBody(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPutAdminUsersUserIdFactorsFactorIdRequestWithBody(c.Server, userId, factorId, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PutAdminUsersUserIdFactorsFactorId(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, body PutAdminUsersUserIdFactorsFactorIdJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPutAdminUsersUserIdFactorsFactorIdRequest(c.Server, userId, factorId, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PostGenerateLinkWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPostGenerateLinkRequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PostGenerateLink(ctx context.Context, body PostGenerateLinkJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPostGenerateLinkRequest(c.Server, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) PostInviteWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPostInviteRequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return 
c.Client.Do(req) +} + +func (c *Client) PostInvite(ctx context.Context, body PostInviteJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewPostInviteRequest(c.Server, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +// NewGetAdminAuditRequest generates requests for GetAdminAudit +func NewGetAdminAuditRequest(server string, params *GetAdminAuditParams) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/audit") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + queryValues := queryURL.Query() + + if params.Page != nil { + + if queryFrag, err := runtime.StyleParamWithLocation("form", true, "page", runtime.ParamLocationQuery, *params.Page); err != nil { + return nil, err + } else if parsed, err := url.ParseQuery(queryFrag); err != nil { + return nil, err + } else { + for k, v := range parsed { + for _, v2 := range v { + queryValues.Add(k, v2) + } + } + } + + } + + if params.PerPage != nil { + + if queryFrag, err := runtime.StyleParamWithLocation("form", true, "per_page", runtime.ParamLocationQuery, *params.PerPage); err != nil { + return nil, err + } else if parsed, err := url.ParseQuery(queryFrag); err != nil { + return nil, err + } else { + for k, v := range parsed { + for _, v2 := range v { + queryValues.Add(k, v2) + } + } + } + + } + + queryURL.RawQuery = queryValues.Encode() + + req, err := http.NewRequest("GET", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewGetAdminSsoProvidersRequest generates requests for GetAdminSsoProviders +func NewGetAdminSsoProvidersRequest(server string) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/sso/providers") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("GET", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewPostAdminSsoProvidersRequest calls the generic PostAdminSsoProviders builder with application/json body +func NewPostAdminSsoProvidersRequest(server string, body PostAdminSsoProvidersJSONRequestBody) (*http.Request, error) { + var bodyReader io.Reader + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + bodyReader = bytes.NewReader(buf) + return NewPostAdminSsoProvidersRequestWithBody(server, "application/json", bodyReader) +} + +// NewPostAdminSsoProvidersRequestWithBody generates requests for PostAdminSsoProviders with any type of body +func NewPostAdminSsoProvidersRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/sso/providers") + if operationPath[0] == '/' { + operationPath = "." 
+ operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + +// NewDeleteAdminSsoProvidersSsoProviderIdRequest generates requests for DeleteAdminSsoProvidersSsoProviderId +func NewDeleteAdminSsoProvidersSsoProviderIdRequest(server string, ssoProviderId openapi_types.UUID) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "ssoProviderId", runtime.ParamLocationPath, ssoProviderId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/sso/providers/%s", pathParam0) + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("DELETE", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewGetAdminSsoProvidersSsoProviderIdRequest generates requests for GetAdminSsoProvidersSsoProviderId +func NewGetAdminSsoProvidersSsoProviderIdRequest(server string, ssoProviderId openapi_types.UUID) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "ssoProviderId", runtime.ParamLocationPath, ssoProviderId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/sso/providers/%s", pathParam0) + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("GET", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewPutAdminSsoProvidersSsoProviderIdRequest calls the generic PutAdminSsoProvidersSsoProviderId builder with application/json body +func NewPutAdminSsoProvidersSsoProviderIdRequest(server string, ssoProviderId openapi_types.UUID, body PutAdminSsoProvidersSsoProviderIdJSONRequestBody) (*http.Request, error) { + var bodyReader io.Reader + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + bodyReader = bytes.NewReader(buf) + return NewPutAdminSsoProvidersSsoProviderIdRequestWithBody(server, ssoProviderId, "application/json", bodyReader) +} + +// NewPutAdminSsoProvidersSsoProviderIdRequestWithBody generates requests for PutAdminSsoProvidersSsoProviderId with any type of body +func NewPutAdminSsoProvidersSsoProviderIdRequestWithBody(server string, ssoProviderId openapi_types.UUID, contentType string, body io.Reader) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "ssoProviderId", runtime.ParamLocationPath, ssoProviderId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/sso/providers/%s", pathParam0) + if operationPath[0] == '/' { + operationPath = "." 
+ operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("PUT", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + +// NewGetAdminUsersRequest generates requests for GetAdminUsers +func NewGetAdminUsersRequest(server string, params *GetAdminUsersParams) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/users") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + queryValues := queryURL.Query() + + if params.Page != nil { + + if queryFrag, err := runtime.StyleParamWithLocation("form", true, "page", runtime.ParamLocationQuery, *params.Page); err != nil { + return nil, err + } else if parsed, err := url.ParseQuery(queryFrag); err != nil { + return nil, err + } else { + for k, v := range parsed { + for _, v2 := range v { + queryValues.Add(k, v2) + } + } + } + + } + + if params.PerPage != nil { + + if queryFrag, err := runtime.StyleParamWithLocation("form", true, "per_page", runtime.ParamLocationQuery, *params.PerPage); err != nil { + return nil, err + } else if parsed, err := url.ParseQuery(queryFrag); err != nil { + return nil, err + } else { + for k, v := range parsed { + for _, v2 := range v { + queryValues.Add(k, v2) + } + } + } + + } + + queryURL.RawQuery = queryValues.Encode() + + req, err := http.NewRequest("GET", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewDeleteAdminUsersUserIdRequest generates requests for DeleteAdminUsersUserId +func NewDeleteAdminUsersUserIdRequest(server string, userId openapi_types.UUID) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "userId", runtime.ParamLocationPath, userId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/users/%s", pathParam0) + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("DELETE", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewGetAdminUsersUserIdRequest generates requests for GetAdminUsersUserId +func NewGetAdminUsersUserIdRequest(server string, userId openapi_types.UUID) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "userId", runtime.ParamLocationPath, userId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/users/%s", pathParam0) + if operationPath[0] == '/' { + operationPath = "." 
+ operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("GET", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewPutAdminUsersUserIdRequest calls the generic PutAdminUsersUserId builder with application/json body +func NewPutAdminUsersUserIdRequest(server string, userId openapi_types.UUID, body PutAdminUsersUserIdJSONRequestBody) (*http.Request, error) { + var bodyReader io.Reader + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + bodyReader = bytes.NewReader(buf) + return NewPutAdminUsersUserIdRequestWithBody(server, userId, "application/json", bodyReader) +} + +// NewPutAdminUsersUserIdRequestWithBody generates requests for PutAdminUsersUserId with any type of body +func NewPutAdminUsersUserIdRequestWithBody(server string, userId openapi_types.UUID, contentType string, body io.Reader) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "userId", runtime.ParamLocationPath, userId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/users/%s", pathParam0) + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("PUT", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + +// NewGetAdminUsersUserIdFactorsRequest generates requests for GetAdminUsersUserIdFactors +func NewGetAdminUsersUserIdFactorsRequest(server string, userId openapi_types.UUID) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "userId", runtime.ParamLocationPath, userId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/users/%s/factors", pathParam0) + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("GET", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewDeleteAdminUsersUserIdFactorsFactorIdRequest generates requests for DeleteAdminUsersUserIdFactorsFactorId +func NewDeleteAdminUsersUserIdFactorsFactorIdRequest(server string, userId openapi_types.UUID, factorId openapi_types.UUID) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "userId", runtime.ParamLocationPath, userId) + if err != nil { + return nil, err + } + + var pathParam1 string + + pathParam1, err = runtime.StyleParamWithLocation("simple", false, "factorId", runtime.ParamLocationPath, factorId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/users/%s/factors/%s", pathParam0, pathParam1) + if operationPath[0] == '/' { + operationPath = "." 
+ operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("DELETE", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + +// NewPutAdminUsersUserIdFactorsFactorIdRequest calls the generic PutAdminUsersUserIdFactorsFactorId builder with application/json body +func NewPutAdminUsersUserIdFactorsFactorIdRequest(server string, userId openapi_types.UUID, factorId openapi_types.UUID, body PutAdminUsersUserIdFactorsFactorIdJSONRequestBody) (*http.Request, error) { + var bodyReader io.Reader + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + bodyReader = bytes.NewReader(buf) + return NewPutAdminUsersUserIdFactorsFactorIdRequestWithBody(server, userId, factorId, "application/json", bodyReader) +} + +// NewPutAdminUsersUserIdFactorsFactorIdRequestWithBody generates requests for PutAdminUsersUserIdFactorsFactorId with any type of body +func NewPutAdminUsersUserIdFactorsFactorIdRequestWithBody(server string, userId openapi_types.UUID, factorId openapi_types.UUID, contentType string, body io.Reader) (*http.Request, error) { + var err error + + var pathParam0 string + + pathParam0, err = runtime.StyleParamWithLocation("simple", false, "userId", runtime.ParamLocationPath, userId) + if err != nil { + return nil, err + } + + var pathParam1 string + + pathParam1, err = runtime.StyleParamWithLocation("simple", false, "factorId", runtime.ParamLocationPath, factorId) + if err != nil { + return nil, err + } + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/admin/users/%s/factors/%s", pathParam0, pathParam1) + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("PUT", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + +// NewPostGenerateLinkRequest calls the generic PostGenerateLink builder with application/json body +func NewPostGenerateLinkRequest(server string, body PostGenerateLinkJSONRequestBody) (*http.Request, error) { + var bodyReader io.Reader + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + bodyReader = bytes.NewReader(buf) + return NewPostGenerateLinkRequestWithBody(server, "application/json", bodyReader) +} + +// NewPostGenerateLinkRequestWithBody generates requests for PostGenerateLink with any type of body +func NewPostGenerateLinkRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/generate_link") + if operationPath[0] == '/' { + operationPath = "." 
+ operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + +// NewPostInviteRequest calls the generic PostInvite builder with application/json body +func NewPostInviteRequest(server string, body PostInviteJSONRequestBody) (*http.Request, error) { + var bodyReader io.Reader + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + bodyReader = bytes.NewReader(buf) + return NewPostInviteRequestWithBody(server, "application/json", bodyReader) +} + +// NewPostInviteRequestWithBody generates requests for PostInvite with any type of body +func NewPostInviteRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/invite") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + +func (c *Client) applyEditors(ctx context.Context, req *http.Request, additionalEditors []RequestEditorFn) error { + for _, r := range c.RequestEditors { + if err := r(ctx, req); err != nil { + return err + } + } + for _, r := range additionalEditors { + if err := r(ctx, req); err != nil { + return err + } + } + return nil +} + +// ClientWithResponses builds on ClientInterface to offer response payloads +type ClientWithResponses struct { + ClientInterface +} + +// NewClientWithResponses creates a new ClientWithResponses, which wraps +// Client with return type handling +func NewClientWithResponses(server string, opts ...ClientOption) (*ClientWithResponses, error) { + client, err := NewClient(server, opts...) + if err != nil { + return nil, err + } + return &ClientWithResponses{client}, nil +} + +// WithBaseURL overrides the baseURL. +func WithBaseURL(baseURL string) ClientOption { + return func(c *Client) error { + newBaseURL, err := url.Parse(baseURL) + if err != nil { + return err + } + c.Server = newBaseURL.String() + return nil + } +} + +// ClientWithResponsesInterface is the interface specification for the client with responses above. 
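+//
+// A minimal, illustrative call site for the typed wrapper (sketch only: the base URL
+// mirrors the 9999 port exposed in the dev compose file further below, and adminToken
+// stands in for a suitably privileged JWT; neither is defined in this file):
+//
+//	withAuth := func(ctx context.Context, req *http.Request) error {
+//		req.Header.Set("Authorization", "Bearer "+adminToken)
+//		return nil
+//	}
+//	ctx := context.Background()
+//	c, err := NewClientWithResponses("http://localhost:9999")
+//	if err != nil {
+//		// handle construction error
+//	}
+//	resp, err := c.GetAdminUsersWithResponse(ctx, &GetAdminUsersParams{}, withAuth)
+//	if err == nil && resp.StatusCode() == 200 && resp.JSON200 != nil {
+//		// resp.JSON200.Users holds the decoded user list; resp.Body keeps the raw payload.
+//	}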
+type ClientWithResponsesInterface interface { + // GetAdminAudit request + GetAdminAuditWithResponse(ctx context.Context, params *GetAdminAuditParams, reqEditors ...RequestEditorFn) (*GetAdminAuditResponse, error) + + // GetAdminSsoProviders request + GetAdminSsoProvidersWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*GetAdminSsoProvidersResponse, error) + + // PostAdminSsoProviders request with any body + PostAdminSsoProvidersWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostAdminSsoProvidersResponse, error) + + PostAdminSsoProvidersWithResponse(ctx context.Context, body PostAdminSsoProvidersJSONRequestBody, reqEditors ...RequestEditorFn) (*PostAdminSsoProvidersResponse, error) + + // DeleteAdminSsoProvidersSsoProviderId request + DeleteAdminSsoProvidersSsoProviderIdWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*DeleteAdminSsoProvidersSsoProviderIdResponse, error) + + // GetAdminSsoProvidersSsoProviderId request + GetAdminSsoProvidersSsoProviderIdWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*GetAdminSsoProvidersSsoProviderIdResponse, error) + + // PutAdminSsoProvidersSsoProviderId request with any body + PutAdminSsoProvidersSsoProviderIdWithBodyWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PutAdminSsoProvidersSsoProviderIdResponse, error) + + PutAdminSsoProvidersSsoProviderIdWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, body PutAdminSsoProvidersSsoProviderIdJSONRequestBody, reqEditors ...RequestEditorFn) (*PutAdminSsoProvidersSsoProviderIdResponse, error) + + // GetAdminUsers request + GetAdminUsersWithResponse(ctx context.Context, params *GetAdminUsersParams, reqEditors ...RequestEditorFn) (*GetAdminUsersResponse, error) + + // DeleteAdminUsersUserId request + DeleteAdminUsersUserIdWithResponse(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*DeleteAdminUsersUserIdResponse, error) + + // GetAdminUsersUserId request + GetAdminUsersUserIdWithResponse(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*GetAdminUsersUserIdResponse, error) + + // PutAdminUsersUserId request with any body + PutAdminUsersUserIdWithBodyWithResponse(ctx context.Context, userId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PutAdminUsersUserIdResponse, error) + + PutAdminUsersUserIdWithResponse(ctx context.Context, userId openapi_types.UUID, body PutAdminUsersUserIdJSONRequestBody, reqEditors ...RequestEditorFn) (*PutAdminUsersUserIdResponse, error) + + // GetAdminUsersUserIdFactors request + GetAdminUsersUserIdFactorsWithResponse(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*GetAdminUsersUserIdFactorsResponse, error) + + // DeleteAdminUsersUserIdFactorsFactorId request + DeleteAdminUsersUserIdFactorsFactorIdWithResponse(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, reqEditors ...RequestEditorFn) (*DeleteAdminUsersUserIdFactorsFactorIdResponse, error) + + // PutAdminUsersUserIdFactorsFactorId request with any body + PutAdminUsersUserIdFactorsFactorIdWithBodyWithResponse(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) 
(*PutAdminUsersUserIdFactorsFactorIdResponse, error) + + PutAdminUsersUserIdFactorsFactorIdWithResponse(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, body PutAdminUsersUserIdFactorsFactorIdJSONRequestBody, reqEditors ...RequestEditorFn) (*PutAdminUsersUserIdFactorsFactorIdResponse, error) + + // PostGenerateLink request with any body + PostGenerateLinkWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostGenerateLinkResponse, error) + + PostGenerateLinkWithResponse(ctx context.Context, body PostGenerateLinkJSONRequestBody, reqEditors ...RequestEditorFn) (*PostGenerateLinkResponse, error) + + // PostInvite request with any body + PostInviteWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostInviteResponse, error) + + PostInviteWithResponse(ctx context.Context, body PostInviteJSONRequestBody, reqEditors ...RequestEditorFn) (*PostInviteResponse, error) +} + +type GetAdminAuditResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *[]struct { + CreatedAt *time.Time `json:"created_at,omitempty"` + Id *openapi_types.UUID `json:"id,omitempty"` + IpAddress *string `json:"ip_address,omitempty"` + Payload *struct { + // Action Usually one of these values: + // - login + // - logout + // - invite_accepted + // - user_signedup + // - user_invited + // - user_deleted + // - user_modified + // - user_recovery_requested + // - user_reauthenticate_requested + // - user_confirmation_requested + // - user_repeated_signup + // - user_updated_password + // - token_revoked + // - token_refreshed + // - generate_recovery_codes + // - factor_in_progress + // - factor_unenrolled + // - challenge_created + // - verification_attempted + // - factor_deleted + // - recovery_codes_deleted + // - factor_updated + // - mfa_code_login + Action *string `json:"action,omitempty"` + ActorId *string `json:"actor_id,omitempty"` + ActorName *string `json:"actor_name,omitempty"` + ActorUsername *string `json:"actor_username,omitempty"` + + // LogType Usually one of these values: + // - account + // - team + // - token + // - user + // - factor + // - recovery_codes + LogType *string `json:"log_type,omitempty"` + Traits *map[string]interface{} `json:"traits,omitempty"` + } `json:"payload,omitempty"` + } + JSON401 *ErrorSchema + JSON403 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r GetAdminAuditResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GetAdminAuditResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type GetAdminSsoProvidersResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *struct { + Items *[]SSOProviderSchema `json:"items,omitempty"` + } +} + +// Status returns HTTPResponse.Status +func (r GetAdminSsoProvidersResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GetAdminSsoProvidersResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type PostAdminSsoProvidersResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *SSOProviderSchema + JSON400 *ErrorSchema + JSON401 *ErrorSchema + JSON403 *ErrorSchema +} + +// Status 
returns HTTPResponse.Status +func (r PostAdminSsoProvidersResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r PostAdminSsoProvidersResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type DeleteAdminSsoProvidersSsoProviderIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *SSOProviderSchema + JSON401 *ErrorSchema + JSON403 *ErrorSchema + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r DeleteAdminSsoProvidersSsoProviderIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r DeleteAdminSsoProvidersSsoProviderIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type GetAdminSsoProvidersSsoProviderIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *SSOProviderSchema + JSON401 *ErrorSchema + JSON403 *ErrorSchema + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r GetAdminSsoProvidersSsoProviderIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GetAdminSsoProvidersSsoProviderIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type PutAdminSsoProvidersSsoProviderIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *SSOProviderSchema + JSON400 *ErrorSchema + JSON401 *ErrorSchema + JSON403 *ErrorSchema + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r PutAdminSsoProvidersSsoProviderIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r PutAdminSsoProvidersSsoProviderIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type GetAdminUsersResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *struct { + Aud *string `json:"aud,omitempty"` + Users *[]UserSchema `json:"users,omitempty"` + } + JSON401 *ErrorSchema + JSON403 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r GetAdminUsersResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GetAdminUsersResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type DeleteAdminUsersUserIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *UserSchema + JSON401 *ErrorSchema + JSON403 *ErrorSchema + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r DeleteAdminUsersUserIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r DeleteAdminUsersUserIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type GetAdminUsersUserIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *UserSchema + JSON401 *ErrorSchema + 
JSON403 *ErrorSchema + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r GetAdminUsersUserIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GetAdminUsersUserIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type PutAdminUsersUserIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *UserSchema + JSON401 *ErrorSchema + JSON403 *ErrorSchema + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r PutAdminUsersUserIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r PutAdminUsersUserIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type GetAdminUsersUserIdFactorsResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *[]MFAFactorSchema + JSON401 *ErrorSchema + JSON403 *ErrorSchema + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r GetAdminUsersUserIdFactorsResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GetAdminUsersUserIdFactorsResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type DeleteAdminUsersUserIdFactorsFactorIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *MFAFactorSchema + JSON401 *ErrorSchema + JSON403 *ErrorSchema + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r DeleteAdminUsersUserIdFactorsFactorIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r DeleteAdminUsersUserIdFactorsFactorIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type PutAdminUsersUserIdFactorsFactorIdResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *MFAFactorSchema + JSON401 *ErrorSchema + JSON403 *ErrorSchema + JSON404 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r PutAdminUsersUserIdFactorsFactorIdResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r PutAdminUsersUserIdFactorsFactorIdResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type PostGenerateLinkResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *struct { + ActionLink *string `json:"action_link,omitempty"` + EmailOtp *string `json:"email_otp,omitempty"` + HashedToken *string `json:"hashed_token,omitempty"` + RedirectTo *string `json:"redirect_to,omitempty"` + VerificationType *string `json:"verification_type,omitempty"` + AdditionalProperties map[string]interface{} `json:"-"` + } + JSON400 *ErrorSchema + JSON401 *ErrorSchema + JSON403 *ErrorSchema + JSON404 *ErrorSchema + JSON422 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r PostGenerateLinkResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return 
http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r PostGenerateLinkResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +type PostInviteResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *UserSchema + JSON400 *ErrorSchema + JSON422 *ErrorSchema +} + +// Status returns HTTPResponse.Status +func (r PostInviteResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r PostInviteResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + +// GetAdminAuditWithResponse request returning *GetAdminAuditResponse +func (c *ClientWithResponses) GetAdminAuditWithResponse(ctx context.Context, params *GetAdminAuditParams, reqEditors ...RequestEditorFn) (*GetAdminAuditResponse, error) { + rsp, err := c.GetAdminAudit(ctx, params, reqEditors...) + if err != nil { + return nil, err + } + return ParseGetAdminAuditResponse(rsp) +} + +// GetAdminSsoProvidersWithResponse request returning *GetAdminSsoProvidersResponse +func (c *ClientWithResponses) GetAdminSsoProvidersWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*GetAdminSsoProvidersResponse, error) { + rsp, err := c.GetAdminSsoProviders(ctx, reqEditors...) + if err != nil { + return nil, err + } + return ParseGetAdminSsoProvidersResponse(rsp) +} + +// PostAdminSsoProvidersWithBodyWithResponse request with arbitrary body returning *PostAdminSsoProvidersResponse +func (c *ClientWithResponses) PostAdminSsoProvidersWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostAdminSsoProvidersResponse, error) { + rsp, err := c.PostAdminSsoProvidersWithBody(ctx, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePostAdminSsoProvidersResponse(rsp) +} + +func (c *ClientWithResponses) PostAdminSsoProvidersWithResponse(ctx context.Context, body PostAdminSsoProvidersJSONRequestBody, reqEditors ...RequestEditorFn) (*PostAdminSsoProvidersResponse, error) { + rsp, err := c.PostAdminSsoProviders(ctx, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePostAdminSsoProvidersResponse(rsp) +} + +// DeleteAdminSsoProvidersSsoProviderIdWithResponse request returning *DeleteAdminSsoProvidersSsoProviderIdResponse +func (c *ClientWithResponses) DeleteAdminSsoProvidersSsoProviderIdWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*DeleteAdminSsoProvidersSsoProviderIdResponse, error) { + rsp, err := c.DeleteAdminSsoProvidersSsoProviderId(ctx, ssoProviderId, reqEditors...) + if err != nil { + return nil, err + } + return ParseDeleteAdminSsoProvidersSsoProviderIdResponse(rsp) +} + +// GetAdminSsoProvidersSsoProviderIdWithResponse request returning *GetAdminSsoProvidersSsoProviderIdResponse +func (c *ClientWithResponses) GetAdminSsoProvidersSsoProviderIdWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, reqEditors ...RequestEditorFn) (*GetAdminSsoProvidersSsoProviderIdResponse, error) { + rsp, err := c.GetAdminSsoProvidersSsoProviderId(ctx, ssoProviderId, reqEditors...) 
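+	// Note: the untyped Client call above only performs the HTTP round trip; the matching
+	// Parse*Response helper invoked below decodes rsp.Body into the typed JSONxxx fields,
+	// keyed on the response status code and Content-Type header.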
+ if err != nil { + return nil, err + } + return ParseGetAdminSsoProvidersSsoProviderIdResponse(rsp) +} + +// PutAdminSsoProvidersSsoProviderIdWithBodyWithResponse request with arbitrary body returning *PutAdminSsoProvidersSsoProviderIdResponse +func (c *ClientWithResponses) PutAdminSsoProvidersSsoProviderIdWithBodyWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PutAdminSsoProvidersSsoProviderIdResponse, error) { + rsp, err := c.PutAdminSsoProvidersSsoProviderIdWithBody(ctx, ssoProviderId, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePutAdminSsoProvidersSsoProviderIdResponse(rsp) +} + +func (c *ClientWithResponses) PutAdminSsoProvidersSsoProviderIdWithResponse(ctx context.Context, ssoProviderId openapi_types.UUID, body PutAdminSsoProvidersSsoProviderIdJSONRequestBody, reqEditors ...RequestEditorFn) (*PutAdminSsoProvidersSsoProviderIdResponse, error) { + rsp, err := c.PutAdminSsoProvidersSsoProviderId(ctx, ssoProviderId, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePutAdminSsoProvidersSsoProviderIdResponse(rsp) +} + +// GetAdminUsersWithResponse request returning *GetAdminUsersResponse +func (c *ClientWithResponses) GetAdminUsersWithResponse(ctx context.Context, params *GetAdminUsersParams, reqEditors ...RequestEditorFn) (*GetAdminUsersResponse, error) { + rsp, err := c.GetAdminUsers(ctx, params, reqEditors...) + if err != nil { + return nil, err + } + return ParseGetAdminUsersResponse(rsp) +} + +// DeleteAdminUsersUserIdWithResponse request returning *DeleteAdminUsersUserIdResponse +func (c *ClientWithResponses) DeleteAdminUsersUserIdWithResponse(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*DeleteAdminUsersUserIdResponse, error) { + rsp, err := c.DeleteAdminUsersUserId(ctx, userId, reqEditors...) + if err != nil { + return nil, err + } + return ParseDeleteAdminUsersUserIdResponse(rsp) +} + +// GetAdminUsersUserIdWithResponse request returning *GetAdminUsersUserIdResponse +func (c *ClientWithResponses) GetAdminUsersUserIdWithResponse(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*GetAdminUsersUserIdResponse, error) { + rsp, err := c.GetAdminUsersUserId(ctx, userId, reqEditors...) + if err != nil { + return nil, err + } + return ParseGetAdminUsersUserIdResponse(rsp) +} + +// PutAdminUsersUserIdWithBodyWithResponse request with arbitrary body returning *PutAdminUsersUserIdResponse +func (c *ClientWithResponses) PutAdminUsersUserIdWithBodyWithResponse(ctx context.Context, userId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PutAdminUsersUserIdResponse, error) { + rsp, err := c.PutAdminUsersUserIdWithBody(ctx, userId, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePutAdminUsersUserIdResponse(rsp) +} + +func (c *ClientWithResponses) PutAdminUsersUserIdWithResponse(ctx context.Context, userId openapi_types.UUID, body PutAdminUsersUserIdJSONRequestBody, reqEditors ...RequestEditorFn) (*PutAdminUsersUserIdResponse, error) { + rsp, err := c.PutAdminUsersUserId(ctx, userId, body, reqEditors...) 
+ if err != nil { + return nil, err + } + return ParsePutAdminUsersUserIdResponse(rsp) +} + +// GetAdminUsersUserIdFactorsWithResponse request returning *GetAdminUsersUserIdFactorsResponse +func (c *ClientWithResponses) GetAdminUsersUserIdFactorsWithResponse(ctx context.Context, userId openapi_types.UUID, reqEditors ...RequestEditorFn) (*GetAdminUsersUserIdFactorsResponse, error) { + rsp, err := c.GetAdminUsersUserIdFactors(ctx, userId, reqEditors...) + if err != nil { + return nil, err + } + return ParseGetAdminUsersUserIdFactorsResponse(rsp) +} + +// DeleteAdminUsersUserIdFactorsFactorIdWithResponse request returning *DeleteAdminUsersUserIdFactorsFactorIdResponse +func (c *ClientWithResponses) DeleteAdminUsersUserIdFactorsFactorIdWithResponse(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, reqEditors ...RequestEditorFn) (*DeleteAdminUsersUserIdFactorsFactorIdResponse, error) { + rsp, err := c.DeleteAdminUsersUserIdFactorsFactorId(ctx, userId, factorId, reqEditors...) + if err != nil { + return nil, err + } + return ParseDeleteAdminUsersUserIdFactorsFactorIdResponse(rsp) +} + +// PutAdminUsersUserIdFactorsFactorIdWithBodyWithResponse request with arbitrary body returning *PutAdminUsersUserIdFactorsFactorIdResponse +func (c *ClientWithResponses) PutAdminUsersUserIdFactorsFactorIdWithBodyWithResponse(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PutAdminUsersUserIdFactorsFactorIdResponse, error) { + rsp, err := c.PutAdminUsersUserIdFactorsFactorIdWithBody(ctx, userId, factorId, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePutAdminUsersUserIdFactorsFactorIdResponse(rsp) +} + +func (c *ClientWithResponses) PutAdminUsersUserIdFactorsFactorIdWithResponse(ctx context.Context, userId openapi_types.UUID, factorId openapi_types.UUID, body PutAdminUsersUserIdFactorsFactorIdJSONRequestBody, reqEditors ...RequestEditorFn) (*PutAdminUsersUserIdFactorsFactorIdResponse, error) { + rsp, err := c.PutAdminUsersUserIdFactorsFactorId(ctx, userId, factorId, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePutAdminUsersUserIdFactorsFactorIdResponse(rsp) +} + +// PostGenerateLinkWithBodyWithResponse request with arbitrary body returning *PostGenerateLinkResponse +func (c *ClientWithResponses) PostGenerateLinkWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostGenerateLinkResponse, error) { + rsp, err := c.PostGenerateLinkWithBody(ctx, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePostGenerateLinkResponse(rsp) +} + +func (c *ClientWithResponses) PostGenerateLinkWithResponse(ctx context.Context, body PostGenerateLinkJSONRequestBody, reqEditors ...RequestEditorFn) (*PostGenerateLinkResponse, error) { + rsp, err := c.PostGenerateLink(ctx, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePostGenerateLinkResponse(rsp) +} + +// PostInviteWithBodyWithResponse request with arbitrary body returning *PostInviteResponse +func (c *ClientWithResponses) PostInviteWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*PostInviteResponse, error) { + rsp, err := c.PostInviteWithBody(ctx, contentType, body, reqEditors...) 
+ if err != nil { + return nil, err + } + return ParsePostInviteResponse(rsp) +} + +func (c *ClientWithResponses) PostInviteWithResponse(ctx context.Context, body PostInviteJSONRequestBody, reqEditors ...RequestEditorFn) (*PostInviteResponse, error) { + rsp, err := c.PostInvite(ctx, body, reqEditors...) + if err != nil { + return nil, err + } + return ParsePostInviteResponse(rsp) +} + +// ParseGetAdminAuditResponse parses an HTTP response from a GetAdminAuditWithResponse call +func ParseGetAdminAuditResponse(rsp *http.Response) (*GetAdminAuditResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GetAdminAuditResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest []struct { + CreatedAt *time.Time `json:"created_at,omitempty"` + Id *openapi_types.UUID `json:"id,omitempty"` + IpAddress *string `json:"ip_address,omitempty"` + Payload *struct { + // Action Usually one of these values: + // - login + // - logout + // - invite_accepted + // - user_signedup + // - user_invited + // - user_deleted + // - user_modified + // - user_recovery_requested + // - user_reauthenticate_requested + // - user_confirmation_requested + // - user_repeated_signup + // - user_updated_password + // - token_revoked + // - token_refreshed + // - generate_recovery_codes + // - factor_in_progress + // - factor_unenrolled + // - challenge_created + // - verification_attempted + // - factor_deleted + // - recovery_codes_deleted + // - factor_updated + // - mfa_code_login + Action *string `json:"action,omitempty"` + ActorId *string `json:"actor_id,omitempty"` + ActorName *string `json:"actor_name,omitempty"` + ActorUsername *string `json:"actor_username,omitempty"` + + // LogType Usually one of these values: + // - account + // - team + // - token + // - user + // - factor + // - recovery_codes + LogType *string `json:"log_type,omitempty"` + Traits *map[string]interface{} `json:"traits,omitempty"` + } `json:"payload,omitempty"` + } + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + } + + return response, nil +} + +// ParseGetAdminSsoProvidersResponse parses an HTTP response from a GetAdminSsoProvidersWithResponse call +func ParseGetAdminSsoProvidersResponse(rsp *http.Response) (*GetAdminSsoProvidersResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GetAdminSsoProvidersResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest struct { + Items *[]SSOProviderSchema `json:"items,omitempty"` + } + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + } + + return response, nil +} + +// ParsePostAdminSsoProvidersResponse parses an HTTP response from 
a PostAdminSsoProvidersWithResponse call +func ParsePostAdminSsoProvidersResponse(rsp *http.Response) (*PostAdminSsoProvidersResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &PostAdminSsoProvidersResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest SSOProviderSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + } + + return response, nil +} + +// ParseDeleteAdminSsoProvidersSsoProviderIdResponse parses an HTTP response from a DeleteAdminSsoProvidersSsoProviderIdWithResponse call +func ParseDeleteAdminSsoProvidersSsoProviderIdResponse(rsp *http.Response) (*DeleteAdminSsoProvidersSsoProviderIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &DeleteAdminSsoProvidersSsoProviderIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest SSOProviderSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParseGetAdminSsoProvidersSsoProviderIdResponse parses an HTTP response from a GetAdminSsoProvidersSsoProviderIdWithResponse call +func ParseGetAdminSsoProvidersSsoProviderIdResponse(rsp *http.Response) (*GetAdminSsoProvidersSsoProviderIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GetAdminSsoProvidersSsoProviderIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest SSOProviderSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && 
rsp.StatusCode == 401: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParsePutAdminSsoProvidersSsoProviderIdResponse parses an HTTP response from a PutAdminSsoProvidersSsoProviderIdWithResponse call +func ParsePutAdminSsoProvidersSsoProviderIdResponse(rsp *http.Response) (*PutAdminSsoProvidersSsoProviderIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &PutAdminSsoProvidersSsoProviderIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest SSOProviderSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParseGetAdminUsersResponse parses an HTTP response from a GetAdminUsersWithResponse call +func ParseGetAdminUsersResponse(rsp *http.Response) (*GetAdminUsersResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GetAdminUsersResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest struct { + Aud *string `json:"aud,omitempty"` + Users *[]UserSchema `json:"users,omitempty"` + } + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + } + + return response, nil +} + +// ParseDeleteAdminUsersUserIdResponse parses an HTTP response from a 
DeleteAdminUsersUserIdWithResponse call +func ParseDeleteAdminUsersUserIdResponse(rsp *http.Response) (*DeleteAdminUsersUserIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &DeleteAdminUsersUserIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest UserSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParseGetAdminUsersUserIdResponse parses an HTTP response from a GetAdminUsersUserIdWithResponse call +func ParseGetAdminUsersUserIdResponse(rsp *http.Response) (*GetAdminUsersUserIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GetAdminUsersUserIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest UserSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParsePutAdminUsersUserIdResponse parses an HTTP response from a PutAdminUsersUserIdWithResponse call +func ParsePutAdminUsersUserIdResponse(rsp *http.Response) (*PutAdminUsersUserIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &PutAdminUsersUserIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest UserSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case 
strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParseGetAdminUsersUserIdFactorsResponse parses an HTTP response from a GetAdminUsersUserIdFactorsWithResponse call +func ParseGetAdminUsersUserIdFactorsResponse(rsp *http.Response) (*GetAdminUsersUserIdFactorsResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GetAdminUsersUserIdFactorsResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest []MFAFactorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParseDeleteAdminUsersUserIdFactorsFactorIdResponse parses an HTTP response from a DeleteAdminUsersUserIdFactorsFactorIdWithResponse call +func ParseDeleteAdminUsersUserIdFactorsFactorIdResponse(rsp *http.Response) (*DeleteAdminUsersUserIdFactorsFactorIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &DeleteAdminUsersUserIdFactorsFactorIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest MFAFactorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParsePutAdminUsersUserIdFactorsFactorIdResponse parses an HTTP response from a PutAdminUsersUserIdFactorsFactorIdWithResponse call +func ParsePutAdminUsersUserIdFactorsFactorIdResponse(rsp *http.Response) 
(*PutAdminUsersUserIdFactorsFactorIdResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &PutAdminUsersUserIdFactorsFactorIdResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest MFAFactorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + } + + return response, nil +} + +// ParsePostGenerateLinkResponse parses an HTTP response from a PostGenerateLinkWithResponse call +func ParsePostGenerateLinkResponse(rsp *http.Response) (*PostGenerateLinkResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &PostGenerateLinkResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest struct { + ActionLink *string `json:"action_link,omitempty"` + EmailOtp *string `json:"email_otp,omitempty"` + HashedToken *string `json:"hashed_token,omitempty"` + RedirectTo *string `json:"redirect_to,omitempty"` + VerificationType *string `json:"verification_type,omitempty"` + AdditionalProperties map[string]interface{} `json:"-"` + } + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON403 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 404: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON404 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON422 = &dest + + } + + return response, nil +} + +// ParsePostInviteResponse parses an HTTP response from a PostInviteWithResponse call +func ParsePostInviteResponse(rsp *http.Response) (*PostInviteResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + 
defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &PostInviteResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest UserSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: + var dest ErrorSchema + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON422 = &dest + + } + + return response, nil +} diff --git a/auth_v2.169.0/client/admin/gen.go b/auth_v2.169.0/client/admin/gen.go new file mode 100644 index 0000000..c0cf6e3 --- /dev/null +++ b/auth_v2.169.0/client/admin/gen.go @@ -0,0 +1,3 @@ +package admin + +//go:generate oapi-codegen -config ./oapi-codegen.yaml ../../openapi.yaml diff --git a/auth_v2.169.0/client/admin/oapi-codegen.yaml b/auth_v2.169.0/client/admin/oapi-codegen.yaml new file mode 100644 index 0000000..a3aa634 --- /dev/null +++ b/auth_v2.169.0/client/admin/oapi-codegen.yaml @@ -0,0 +1,7 @@ +package: admin +generate: + - client + - types +include-tags: + - admin +output: client.go diff --git a/auth_v2.169.0/cmd/admin_cmd.go b/auth_v2.169.0/cmd/admin_cmd.go new file mode 100644 index 0000000..7997bb5 --- /dev/null +++ b/auth_v2.169.0/cmd/admin_cmd.go @@ -0,0 +1,131 @@ +package cmd + +import ( + "github.com/gofrs/uuid" + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +var autoconfirm, isAdmin bool +var audience string + +func getAudience(c *conf.GlobalConfiguration) string { + if audience == "" { + return c.JWT.Aud + } + + return audience +} + +func adminCmd() *cobra.Command { + var adminCmd = &cobra.Command{ + Use: "admin", + } + + adminCmd.AddCommand(&adminCreateUserCmd, &adminDeleteUserCmd) + adminCmd.PersistentFlags().StringVarP(&audience, "aud", "a", "", "Set the new user's audience") + + adminCreateUserCmd.Flags().BoolVar(&autoconfirm, "confirm", false, "Automatically confirm user without sending an email") + adminCreateUserCmd.Flags().BoolVar(&isAdmin, "admin", false, "Create user with admin privileges") + + return adminCmd +} + +var adminCreateUserCmd = cobra.Command{ + Use: "createuser", + Run: func(cmd *cobra.Command, args []string) { + if len(args) < 2 { + logrus.Fatal("Not enough arguments to createuser command. Expected at least email and password values") + return + } + + execWithConfigAndArgs(cmd, adminCreateUser, args) + }, +} + +var adminDeleteUserCmd = cobra.Command{ + Use: "deleteuser", + Run: func(cmd *cobra.Command, args []string) { + if len(args) < 1 { + logrus.Fatal("Not enough arguments to deleteuser command. 
Expected at least ID or email") + return + } + + execWithConfigAndArgs(cmd, adminDeleteUser, args) + }, +} + +func adminCreateUser(config *conf.GlobalConfiguration, args []string) { + db, err := storage.Dial(config) + if err != nil { + logrus.Fatalf("Error opening database: %+v", err) + } + defer db.Close() + + aud := getAudience(config) + if user, err := models.IsDuplicatedEmail(db, args[0], aud, nil); user != nil { + logrus.Fatalf("Error creating new user: user already exists") + } else if err != nil { + logrus.Fatalf("Error checking user email: %+v", err) + } + + user, err := models.NewUser("", args[0], args[1], aud, nil) + if err != nil { + logrus.Fatalf("Error creating new user: %+v", err) + } + + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if terr = tx.Create(user); terr != nil { + return terr + } + + if len(args) > 2 { + if terr = user.SetRole(tx, args[2]); terr != nil { + return terr + } + } else if isAdmin { + if terr = user.SetRole(tx, config.JWT.AdminGroupName); terr != nil { + return terr + } + } + + if config.Mailer.Autoconfirm || autoconfirm { + if terr = user.Confirm(tx); terr != nil { + return terr + } + } + return nil + }) + if err != nil { + logrus.Fatalf("Unable to create user (%s): %+v", args[0], err) + } + + logrus.Infof("Created user: %s", args[0]) +} + +func adminDeleteUser(config *conf.GlobalConfiguration, args []string) { + db, err := storage.Dial(config) + if err != nil { + logrus.Fatalf("Error opening database: %+v", err) + } + defer db.Close() + + user, err := models.FindUserByEmailAndAudience(db, args[0], getAudience(config)) + if err != nil { + userID := uuid.Must(uuid.FromString(args[0])) + user, err = models.FindUserByID(db, userID) + if err != nil { + logrus.Fatalf("Error finding user (%s): %+v", userID, err) + } + } + + if err = db.Destroy(user); err != nil { + logrus.Fatalf("Error removing user (%s): %+v", args[0], err) + } + + logrus.Infof("Removed user: %s", args[0]) +} diff --git a/auth_v2.169.0/cmd/migrate_cmd.go b/auth_v2.169.0/cmd/migrate_cmd.go new file mode 100644 index 0000000..e0251d6 --- /dev/null +++ b/auth_v2.169.0/cmd/migrate_cmd.go @@ -0,0 +1,117 @@ +package cmd + +import ( + "embed" + "fmt" + "net/url" + "os" + + "github.com/gobuffalo/pop/v6" + "github.com/gobuffalo/pop/v6/logging" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +var EmbeddedMigrations embed.FS + +var migrateCmd = cobra.Command{ + Use: "migrate", + Long: "Migrate database strucutures. 
This will create new tables and add missing columns and indexes.", + Run: migrate, +} + +func migrate(cmd *cobra.Command, args []string) { + globalConfig := loadGlobalConfig(cmd.Context()) + + if globalConfig.DB.Driver == "" && globalConfig.DB.URL != "" { + u, err := url.Parse(globalConfig.DB.URL) + if err != nil { + logrus.Fatalf("%+v", errors.Wrap(err, "parsing db connection url")) + } + globalConfig.DB.Driver = u.Scheme + } + + log := logrus.StandardLogger() + + pop.Debug = false + if globalConfig.Logging.Level != "" { + level, err := logrus.ParseLevel(globalConfig.Logging.Level) + if err != nil { + log.Fatalf("Failed to parse log level: %+v", err) + } + log.SetLevel(level) + if level == logrus.DebugLevel { + // Set to true to display query info + pop.Debug = true + } + if level != logrus.DebugLevel { + var noopLogger = func(lvl logging.Level, s string, args ...interface{}) { + } + // Hide pop migration logging + pop.SetLogger(noopLogger) + } + } + + u, _ := url.Parse(globalConfig.DB.URL) + processedUrl := globalConfig.DB.URL + if len(u.Query()) != 0 { + processedUrl = fmt.Sprintf("%s&application_name=gotrue_migrations", processedUrl) + } else { + processedUrl = fmt.Sprintf("%s?application_name=gotrue_migrations", processedUrl) + } + deets := &pop.ConnectionDetails{ + Dialect: globalConfig.DB.Driver, + URL: processedUrl, + } + deets.Options = map[string]string{ + "migration_table_name": "schema_migrations", + "Namespace": globalConfig.DB.Namespace, + } + + db, err := pop.NewConnection(deets) + if err != nil { + log.Fatalf("%+v", errors.Wrap(err, "opening db connection")) + } + defer db.Close() + + if err := db.Open(); err != nil { + log.Fatalf("%+v", errors.Wrap(err, "checking database connection")) + } + + log.Debugf("Reading migrations from executable") + box, err := pop.NewMigrationBox(EmbeddedMigrations, db) + if err != nil { + log.Fatalf("%+v", errors.Wrap(err, "creating db migrator")) + } + + mig := box.Migrator + + log.Debugf("before status") + + if log.Level == logrus.DebugLevel { + err = mig.Status(os.Stdout) + if err != nil { + log.Fatalf("%+v", errors.Wrap(err, "migration status")) + } + } + + // turn off schema dump + mig.SchemaPath = "" + + err = mig.Up() + if err != nil { + log.Fatalf("%v", errors.Wrap(err, "running db migrations")) + } else { + log.Infof("GoTrue migrations applied successfully") + } + + log.Debugf("after status") + + if log.Level == logrus.DebugLevel { + err = mig.Status(os.Stdout) + if err != nil { + log.Fatalf("%+v", errors.Wrap(err, "migration status")) + } + } +} diff --git a/auth_v2.169.0/cmd/root_cmd.go b/auth_v2.169.0/cmd/root_cmd.go new file mode 100644 index 0000000..e8783d4 --- /dev/null +++ b/auth_v2.169.0/cmd/root_cmd.go @@ -0,0 +1,63 @@ +package cmd + +import ( + "context" + + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/observability" +) + +var ( + configFile = "" + watchDir = "" +) + +var rootCmd = cobra.Command{ + Use: "gotrue", + Run: func(cmd *cobra.Command, args []string) { + migrate(cmd, args) + serve(cmd.Context()) + }, +} + +// RootCommand will setup and return the root command +func RootCommand() *cobra.Command { + rootCmd.AddCommand(&serveCmd, &migrateCmd, &versionCmd, adminCmd()) + rootCmd.PersistentFlags().StringVarP(&configFile, "config", "c", "", "base configuration file to load") + rootCmd.PersistentFlags().StringVarP(&watchDir, "config-dir", "d", "", "directory containing a sorted list of config files to watch for changes") + return 
+}
+
+func loadGlobalConfig(ctx context.Context) *conf.GlobalConfiguration {
+	if ctx == nil {
+		panic("context must not be nil")
+	}
+
+	config, err := conf.LoadGlobal(configFile)
+	if err != nil {
+		logrus.Fatalf("Failed to load configuration: %+v", err)
+	}
+
+	if err := observability.ConfigureLogging(&config.Logging); err != nil {
+		logrus.WithError(err).Error("unable to configure logging")
+	}
+
+	if err := observability.ConfigureTracing(ctx, &config.Tracing); err != nil {
+		logrus.WithError(err).Error("unable to configure tracing")
+	}
+
+	if err := observability.ConfigureMetrics(ctx, &config.Metrics); err != nil {
+		logrus.WithError(err).Error("unable to configure metrics")
+	}
+
+	if err := observability.ConfigureProfiler(ctx, &config.Profiler); err != nil {
+		logrus.WithError(err).Error("unable to configure profiler")
+	}
+	return config
+}
+
+func execWithConfigAndArgs(cmd *cobra.Command, fn func(config *conf.GlobalConfiguration, args []string), args []string) {
+	fn(loadGlobalConfig(cmd.Context()), args)
+}
diff --git a/auth_v2.169.0/cmd/serve_cmd.go b/auth_v2.169.0/cmd/serve_cmd.go
new file mode 100644
index 0000000..06fa2f5
--- /dev/null
+++ b/auth_v2.169.0/cmd/serve_cmd.go
@@ -0,0 +1,111 @@
+package cmd
+
+import (
+	"context"
+	"net"
+	"net/http"
+	"sync"
+	"time"
+
+	"github.com/pkg/errors"
+	"github.com/sirupsen/logrus"
+	"github.com/spf13/cobra"
+	"github.com/supabase/auth/internal/api"
+	"github.com/supabase/auth/internal/conf"
+	"github.com/supabase/auth/internal/reloader"
+	"github.com/supabase/auth/internal/storage"
+	"github.com/supabase/auth/internal/utilities"
+)
+
+var serveCmd = cobra.Command{
+	Use:  "serve",
+	Long: "Start API server",
+	Run: func(cmd *cobra.Command, args []string) {
+		serve(cmd.Context())
+	},
+}
+
+func serve(ctx context.Context) {
+	if err := conf.LoadFile(configFile); err != nil {
+		logrus.WithError(err).Fatal("unable to load config")
+	}
+
+	if err := conf.LoadDirectory(watchDir); err != nil {
+		logrus.WithError(err).Fatal("unable to load config from watch dir")
+	}
+
+	config, err := conf.LoadGlobalFromEnv()
+	if err != nil {
+		logrus.WithError(err).Fatal("unable to load config")
+	}
+
+	db, err := storage.Dial(config)
+	if err != nil {
+		logrus.Fatalf("error opening database: %+v", err)
+	}
+	defer db.Close()
+
+	addr := net.JoinHostPort(config.API.Host, config.API.Port)
+	logrus.Infof("GoTrue API started on: %s", addr)
+
+	opts := []api.Option{
+		api.NewLimiterOptions(config),
+	}
+	a := api.NewAPIWithVersion(config, db, utilities.Version, opts...)
+	ah := reloader.NewAtomicHandler(a)
+
+	baseCtx, baseCancel := context.WithCancel(context.Background())
+	defer baseCancel()
+
+	httpSrv := &http.Server{
+		Addr:              addr,
+		Handler:           ah,
+		ReadHeaderTimeout: 2 * time.Second, // to mitigate a Slowloris attack
+		BaseContext: func(net.Listener) context.Context {
+			return baseCtx
+		},
+	}
+	log := logrus.WithField("component", "api")
+
+	var wg sync.WaitGroup
+	defer wg.Wait() // Do not return to caller until this goroutine is done.
+
+	if watchDir != "" {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+
+			fn := func(latestCfg *conf.GlobalConfiguration) {
+				log.Info("reloading api with new configuration")
+				latestAPI := api.NewAPIWithVersion(
+					latestCfg, db, utilities.Version, opts...)
+				ah.Store(latestAPI)
+			}
+
+			rl := reloader.NewReloader(watchDir)
+			if err := rl.Watch(ctx, fn); err != nil {
+				log.WithError(err).Error("watcher is exiting")
+			}
+		}()
+	}
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+
+		<-ctx.Done()
+
+		defer baseCancel() // close baseContext
+
+		shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Minute)
+		defer shutdownCancel()
+
+		if err := httpSrv.Shutdown(shutdownCtx); err != nil && !errors.Is(err, context.Canceled) {
+			log.WithError(err).Error("shutdown failed")
+		}
+	}()
+
+	if err := httpSrv.ListenAndServe(); err != http.ErrServerClosed {
+		log.WithError(err).Fatal("http server listen failed")
+	}
+}
diff --git a/auth_v2.169.0/cmd/version_cmd.go b/auth_v2.169.0/cmd/version_cmd.go
new file mode 100644
index 0000000..cb555d4
--- /dev/null
+++ b/auth_v2.169.0/cmd/version_cmd.go
@@ -0,0 +1,17 @@
+package cmd
+
+import (
+	"fmt"
+
+	"github.com/spf13/cobra"
+	"github.com/supabase/auth/internal/utilities"
+)
+
+var versionCmd = cobra.Command{
+	Run: showVersion,
+	Use: "version",
+}
+
+func showVersion(cmd *cobra.Command, args []string) {
+	fmt.Println(utilities.Version)
+}
diff --git a/auth_v2.169.0/docker-compose-dev.yml b/auth_v2.169.0/docker-compose-dev.yml
new file mode 100644
index 0000000..47ae53d
--- /dev/null
+++ b/auth_v2.169.0/docker-compose-dev.yml
@@ -0,0 +1,34 @@
+version: "3.9"
+services:
+  auth:
+    container_name: auth
+    depends_on:
+      - postgres
+    build:
+      context: ./
+      dockerfile: Dockerfile.dev
+    ports:
+      - '9999:9999'
+      - '9100:9100'
+    environment:
+      - GOTRUE_DB_MIGRATIONS_PATH=/go/src/github.com/supabase/auth/migrations
+    volumes:
+      - ./:/go/src/github.com/supabase/auth
+    command: CompileDaemon --build="make build" --directory=/go/src/github.com/supabase/auth --recursive=true -pattern="(.+\.go|.+\.env)" -exclude=auth -exclude=auth-arm64 -exclude=.env --command="/go/src/github.com/supabase/auth/auth -c=.env.docker"
+  postgres:
+    build:
+      context: .
+      dockerfile: Dockerfile.postgres.dev
+    container_name: auth_postgres
+    ports:
+      - '5432:5432'
+    volumes:
+      - postgres_data:/var/lib/postgresql/data
+    environment:
+      - POSTGRES_USER=postgres
+      - POSTGRES_PASSWORD=root
+      - POSTGRES_DB=postgres
+      # sets the schema name, this should match the `NAMESPACE` env var set in your .env file
+      - DB_NAMESPACE=auth
+volumes:
+  postgres_data:
diff --git a/auth_v2.169.0/docs/admin.go b/auth_v2.169.0/docs/admin.go
new file mode 100644
index 0000000..5c89222
--- /dev/null
+++ b/auth_v2.169.0/docs/admin.go
@@ -0,0 +1,106 @@
+//lint:file-ignore U1000 ignore go-swagger template
+package docs
+
+import (
+	"github.com/supabase/auth/internal/api"
+)
+
+// swagger:route GET /admin/users admin admin-list-users
+// List all users.
+// security:
+// - bearer:
+// responses:
+// 200: adminListUserResponse
+// 401: unauthorizedError
+
+// The list of users.
+// swagger:response adminListUserResponse
+type adminListUserResponseWrapper struct {
+	// in:body
+	Body api.AdminListUsersResponse
+}
+
+// swagger:route POST /admin/users admin admin-create-user
+// Returns the created user.
+// security:
+// - bearer:
+// responses:
+// 200: userResponse
+// 401: unauthorizedError
+
+// The user to be created.
+// swagger:parameters admin-create-user
+type adminUserParamsWrapper struct {
+	// in:body
+	Body api.AdminUserParams
+}
+
+// swagger:route GET /admin/user/{user_id} admin admin-get-user
+// Get a user.
+// security:
+// - bearer:
+// parameters:
+// + name: user_id
+// in: path
+// description: The user's id
+// required: true
+// responses:
+// 200: userResponse
+// 401: unauthorizedError

// The user specified.
// swagger:response userResponse

+// swagger:route PUT /admin/user/{user_id} admin admin-update-user
+// Update a user.
+// security:
+// - bearer:
+// parameters:
+// + name: user_id
+// in: path
+// description: The user's id
+// required: true
+// responses:
+// 200: userResponse
+// 401: unauthorizedError

+// The updated user.
+// swagger:response userResponse

+// swagger:route DELETE /admin/user/{user_id} admin admin-delete-user
+// Deletes a user.
+// security:
+// - bearer:
+// parameters:
+// + name: user_id
+// in: path
+// description: The user's id
+// required: true
+// responses:
+// 200: deleteUserResponse
+// 401: unauthorizedError

+// The updated user.
+// swagger:response deleteUserResponse
+type deleteUserResponseWrapper struct{}
+
+// swagger:route POST /admin/generate_link admin admin-generate-link
+// Generates an email action link.
+// security:
+// - bearer:
+// responses:
+// 200: generateLinkResponse
+// 401: unauthorizedError
+
+// swagger:parameters admin-generate-link
+type generateLinkParams struct {
+	// in:body
+	Body api.GenerateLinkParams
+}
+
+// The response object for generate link.
+// swagger:response generateLinkResponse
+type generateLinkResponseWrapper struct {
+	// in:body
+	Body api.GenerateLinkResponse
+}
diff --git a/auth_v2.169.0/docs/doc.go b/auth_v2.169.0/docs/doc.go
new file mode 100644
index 0000000..5d5a148
--- /dev/null
+++ b/auth_v2.169.0/docs/doc.go
@@ -0,0 +1,20 @@
+// Package classification gotrue
+//
+// Documentation of the gotrue API.
+//
+// Schemes: http, https
+// BasePath: /
+// Version: 1.0.0
+// Host: localhost:9999
+//
+// SecurityDefinitions:
+// bearer:
+// type: apiKey
+// name: Authentication
+// in: header
+//
+// Produces:
+// - application/json
+//
+// swagger:meta
+package docs
diff --git a/auth_v2.169.0/docs/errors.go b/auth_v2.169.0/docs/errors.go
new file mode 100644
index 0000000..a406445
--- /dev/null
+++ b/auth_v2.169.0/docs/errors.go
@@ -0,0 +1,6 @@
+//lint:file-ignore U1000 ignore go-swagger template
+package docs
+
+// This endpoint requires a bearer token.
+// swagger:response unauthorizedError
+type unauthorizedError struct{}
diff --git a/auth_v2.169.0/docs/health.go b/auth_v2.169.0/docs/health.go
new file mode 100644
index 0000000..3034fad
--- /dev/null
+++ b/auth_v2.169.0/docs/health.go
@@ -0,0 +1,15 @@
+//lint:file-ignore U1000 ignore go-swagger template
+package docs
+
+import "github.com/supabase/auth/internal/api"
+
+// swagger:route GET /health health health
+// The healthcheck endpoint for gotrue. Returns the current gotrue version.
+// responses:
+// 200: healthCheckResponse
+
+// swagger:response healthCheckResponse
+type healthCheckResponseWrapper struct {
+	// in:body
+	Body api.HealthCheckResponse
+}
diff --git a/auth_v2.169.0/docs/invite.go b/auth_v2.169.0/docs/invite.go
new file mode 100644
index 0000000..2775cfb
--- /dev/null
+++ b/auth_v2.169.0/docs/invite.go
@@ -0,0 +1,18 @@
+//lint:file-ignore U1000 ignore go-swagger template
+package docs
+
+import "github.com/supabase/auth/internal/api"
+
+// swagger:route POST /invite invite invite
+// Sends an invite link to the user.
+// responses:
+// 200: inviteResponse
+
+// swagger:parameters invite
+type inviteParamsWrapper struct {
+	// in:body
+	Body api.InviteParams
+}
+
+// swagger:response inviteResponse
+type inviteResponseWrapper struct{}
diff --git a/auth_v2.169.0/docs/logout.go b/auth_v2.169.0/docs/logout.go
new file mode 100644
index 0000000..f0b1125
--- /dev/null
+++ b/auth_v2.169.0/docs/logout.go
@@ -0,0 +1,12 @@
+//lint:file-ignore U1000 ignore go-swagger template
+package docs
+
+// swagger:route POST /logout logout logout
+// Logs out the user.
+// security:
+// - bearer:
+// responses:
+// 204: logoutResponse
+
+// swagger:response logoutResponse
+type logoutResponseWrapper struct{}
diff --git a/auth_v2.169.0/docs/oauth.go b/auth_v2.169.0/docs/oauth.go
new file mode 100644
index 0000000..9b3c223
--- /dev/null
+++ b/auth_v2.169.0/docs/oauth.go
@@ -0,0 +1,25 @@
+//lint:file-ignore U1000 ignore go-swagger template
+package docs
+
+// swagger:route GET /authorize oauth authorize
+// Redirects the user to the 3rd-party OAuth provider to start the OAuth1.0 or OAuth2.0 authentication process.
+// parameters:
+// + name: redirect_to
+// in: query
+// description: The redirect url to return the user to after the `/callback` endpoint has completed.
+// required: false
+// responses:
+// 302: authorizeResponse
+
+// Redirects user to the 3rd-party OAuth provider
+// swagger:response authorizeResponse
+type authorizeResponseWrapper struct{}
+
+// swagger:route GET /callback oauth callback
+// Receives the redirect from an external provider during the OAuth authentication process. Starts the process of creating an access and refresh token.
+// responses:
+// 302: callbackResponse
+
+// Redirects user to the redirect url specified in `/authorize`. If no `redirect_url` is provided, the user will be redirected to the `SITE_URL`.
+// swagger:response callbackResponse
+type callbackResponseWrapper struct{}
diff --git a/auth_v2.169.0/docs/otp.go b/auth_v2.169.0/docs/otp.go
new file mode 100644
index 0000000..a62fa07
--- /dev/null
+++ b/auth_v2.169.0/docs/otp.go
@@ -0,0 +1,19 @@
+//lint:file-ignore U1000 ignore go-swagger template
+package docs
+
+import "github.com/supabase/auth/internal/api"
+
+// swagger:route POST /otp otp otp
+// Passwordless sign-in method for email or phone.
+// responses:
+// 200: otpResponse
+
+// swagger:parameters otp
+type otpParamsWrapper struct {
+	// Only an email or phone should be provided.
+	// in:body
+	Body api.OtpParams
+}
+
+// swagger:response otpResponse
+type otpResponseWrapper struct{}
diff --git a/auth_v2.169.0/docs/recover.go b/auth_v2.169.0/docs/recover.go
new file mode 100644
index 0000000..1bd249a
--- /dev/null
+++ b/auth_v2.169.0/docs/recover.go
@@ -0,0 +1,18 @@
+//lint:file-ignore U1000 ignore go-swagger template
+package docs
+
+import "github.com/supabase/auth/internal/api"
+
+// swagger:route POST /recover recovery recovery
+// Sends a password recovery email link to the user's email.
+// responses:
+// 200: recoveryResponse
+
+// swagger:parameters recovery
+type recoveryParamsWrapper struct {
+	// in:body
+	Body api.RecoverParams
+}
+
+// swagger:response recoveryResponse
+type recoveryResponseWrapper struct{}
diff --git a/auth_v2.169.0/docs/settings.go b/auth_v2.169.0/docs/settings.go
new file mode 100644
index 0000000..ff5d4ed
--- /dev/null
+++ b/auth_v2.169.0/docs/settings.go
@@ -0,0 +1,15 @@
+//lint:file-ignore U1000 ignore go-swagger template
+package docs
+
+import "github.com/supabase/auth/internal/api"
+
+// swagger:route GET /settings settings settings
+// Returns the configuration settings for the gotrue server.
+// responses:
+// 200: settingsResponse
+
+// swagger:response settingsResponse
+type settingsResponseWrapper struct {
+	// in:body
+	Body api.Settings
+}
diff --git a/auth_v2.169.0/docs/signup.go b/auth_v2.169.0/docs/signup.go
new file mode 100644
index 0000000..a69f015
--- /dev/null
+++ b/auth_v2.169.0/docs/signup.go
@@ -0,0 +1,17 @@
+//lint:file-ignore U1000 ignore go-swagger template
+package docs
+
+import (
+	"github.com/supabase/auth/internal/api"
+)
+
+// swagger:route POST /signup signup signup
+// Password-based signup with either email or phone.
+// responses:
+// 200: userResponse
+
+// swagger:parameters signup
+type signupParamsWrapper struct {
+	// in:body
+	Body api.SignupParams
+}
diff --git a/auth_v2.169.0/docs/token.go b/auth_v2.169.0/docs/token.go
new file mode 100644
index 0000000..b4ae542
--- /dev/null
+++ b/auth_v2.169.0/docs/token.go
@@ -0,0 +1,34 @@
+//lint:file-ignore U1000 ignore go-swagger template
+package docs
+
+import (
+	"github.com/supabase/auth/internal/api"
+)
+
+// swagger:route POST /token?grant_type=password token token-password
+// Signs in a user with a password.
+// responses:
+// 200: tokenResponse
+
+// swagger:parameters token-password
+type tokenPasswordGrantParamsWrapper struct {
+	// in:body
+	Body api.PasswordGrantParams
+}
+
+// swagger:route POST /token?grant_type=refresh_token token token-refresh
+// Refreshes a user's refresh token.
+// responses:
+// 200: tokenResponse
+
+// swagger:parameters token-refresh
+type tokenRefreshTokenGrantParamsWrapper struct {
+	// in:body
+	Body api.RefreshTokenGrantParams
+}
+
+// swagger:response tokenResponse
+type tokenResponseWrapper struct {
+	// in:body
+	Body api.AccessTokenResponse
+}
diff --git a/auth_v2.169.0/docs/user.go b/auth_v2.169.0/docs/user.go
new file mode 100644
index 0000000..464abfc
--- /dev/null
+++ b/auth_v2.169.0/docs/user.go
@@ -0,0 +1,37 @@
+//lint:file-ignore U1000 ignore go-swagger template
+package docs
+
+import (
+	"github.com/supabase/auth/internal/api"
+	"github.com/supabase/auth/internal/models"
+)
+
+// swagger:route GET /user user user-get
+// Get information for the logged-in user.
+// security:
+// - bearer:
+// responses:
+// 200: userResponse
+// 401: unauthorizedError
+
+// The current user.
+// swagger:response userResponse
+type userResponseWrapper struct {
+	// in:body
+	Body models.User
+}
+
+// swagger:route PUT /user user user-put
+// Returns the updated user.
+// security:
+// - bearer:
+// responses:
+// 200: userResponse
+// 401: unauthorizedError
+
+// The current user.
+// swagger:parameters user-put
+type userUpdateParams struct {
+	// in:body
+	Body api.UserUpdateParams
+}
diff --git a/auth_v2.169.0/docs/verify.go b/auth_v2.169.0/docs/verify.go
new file mode 100644
index 0000000..3590650
--- /dev/null
+++ b/auth_v2.169.0/docs/verify.go
@@ -0,0 +1,24 @@
+//lint:file-ignore U1000 ignore go-swagger template
+package docs
+
+import (
+	"github.com/supabase/auth/internal/api"
+)
+
+// swagger:route GET /verify verify verify-get
+// Verifies a sign up.
+
+// swagger:parameters verify-get
+type verifyGetParamsWrapper struct {
+	// in:query
+	api.VerifyParams
+}
+
+// swagger:route POST /verify verify verify-post
+// Verifies a sign up.
+
+// swagger:parameters verify-post
+type verifyPostParamsWrapper struct {
+	// in:body
+	Body api.VerifyParams
+}
diff --git a/auth_v2.169.0/example.docker.env b/auth_v2.169.0/example.docker.env
new file mode 100644
index 0000000..477a5d1
--- /dev/null
+++ b/auth_v2.169.0/example.docker.env
@@ -0,0 +1,8 @@
+GOTRUE_SITE_URL="http://localhost:3000"
+GOTRUE_JWT_SECRET=""
+GOTRUE_DB_MIGRATIONS_PATH=/go/src/github.com/supabase/auth/migrations
+GOTRUE_DB_DRIVER=postgres
+DATABASE_URL=postgres://supabase_auth_admin:root@postgres:5432/postgres
+GOTRUE_API_HOST=0.0.0.0
+API_EXTERNAL_URL="http://localhost:9999"
+PORT=9999
diff --git a/auth_v2.169.0/example.env b/auth_v2.169.0/example.env
new file mode 100644
index 0000000..e645c96
--- /dev/null
+++ b/auth_v2.169.0/example.env
@@ -0,0 +1,238 @@
+# General Config
+# NOTE: The service_role key is required as an authorization header for /admin endpoints
+
+GOTRUE_JWT_SECRET="CHANGE-THIS! VERY IMPORTANT!"
+GOTRUE_JWT_EXP="3600"
+GOTRUE_JWT_AUD="authenticated"
+GOTRUE_JWT_DEFAULT_GROUP_NAME="authenticated"
+GOTRUE_JWT_ADMIN_ROLES="supabase_admin,service_role"
+
+# Database & API connection details
+GOTRUE_DB_DRIVER="postgres"
+DB_NAMESPACE="auth"
+DATABASE_URL="postgres://supabase_auth_admin:root@localhost:5432/postgres"
+API_EXTERNAL_URL="http://localhost:9999"
+GOTRUE_API_HOST="localhost"
+PORT="9999"
+
+# SMTP config (generate credentials for signup to work)
+GOTRUE_SMTP_HOST=""
+GOTRUE_SMTP_PORT=""
+GOTRUE_SMTP_USER=""
+GOTRUE_SMTP_MAX_FREQUENCY="5s"
+GOTRUE_SMTP_PASS=""
+GOTRUE_SMTP_ADMIN_EMAIL=""
+GOTRUE_SMTP_SENDER_NAME=""
+
+# Mailer config
+GOTRUE_MAILER_AUTOCONFIRM="true"
+GOTRUE_MAILER_URLPATHS_CONFIRMATION="/verify"
+GOTRUE_MAILER_URLPATHS_INVITE="/verify"
+GOTRUE_MAILER_URLPATHS_RECOVERY="/verify"
+GOTRUE_MAILER_URLPATHS_EMAIL_CHANGE="/verify"
+GOTRUE_MAILER_SUBJECTS_CONFIRMATION="Confirm Your Email"
+GOTRUE_MAILER_SUBJECTS_RECOVERY="Reset Your Password"
+GOTRUE_MAILER_SUBJECTS_MAGIC_LINK="Your Magic Link"
+GOTRUE_MAILER_SUBJECTS_EMAIL_CHANGE="Confirm Email Change"
+GOTRUE_MAILER_SUBJECTS_INVITE="You have been invited"
+GOTRUE_MAILER_SECURE_EMAIL_CHANGE_ENABLED="true"
+
+# Custom mailer template config
+GOTRUE_MAILER_TEMPLATES_INVITE=""
+GOTRUE_MAILER_TEMPLATES_CONFIRMATION=""
+GOTRUE_MAILER_TEMPLATES_RECOVERY=""
+GOTRUE_MAILER_TEMPLATES_MAGIC_LINK=""
+GOTRUE_MAILER_TEMPLATES_EMAIL_CHANGE=""
+
+# Signup config
+GOTRUE_DISABLE_SIGNUP="false"
+GOTRUE_SITE_URL="http://localhost:3000"
+GOTRUE_EXTERNAL_EMAIL_ENABLED="true"
+GOTRUE_EXTERNAL_PHONE_ENABLED="true"
+GOTRUE_EXTERNAL_IOS_BUNDLE_ID="com.supabase.auth"
+
+# Whitelist redirect to URLs here, a comma separated list of URIs (e.g. "https://foo.example.com,https://*.foo.example.com,https://bar.example.com")
+GOTRUE_URI_ALLOW_LIST="http://localhost:3000"
+
+# Apple OAuth config
+GOTRUE_EXTERNAL_APPLE_ENABLED="false"
+GOTRUE_EXTERNAL_APPLE_CLIENT_ID=""
+GOTRUE_EXTERNAL_APPLE_SECRET=""
+GOTRUE_EXTERNAL_APPLE_REDIRECT_URI="http://localhost:9999/callback"
+
+# Azure OAuth config
+GOTRUE_EXTERNAL_AZURE_ENABLED="false"
+GOTRUE_EXTERNAL_AZURE_CLIENT_ID=""
+GOTRUE_EXTERNAL_AZURE_SECRET=""
+GOTRUE_EXTERNAL_AZURE_REDIRECT_URI="https://localhost:9999/callback"
+
+# Bitbucket OAuth config
+GOTRUE_EXTERNAL_BITBUCKET_ENABLED="false"
+GOTRUE_EXTERNAL_BITBUCKET_CLIENT_ID=""
+GOTRUE_EXTERNAL_BITBUCKET_SECRET=""
+GOTRUE_EXTERNAL_BITBUCKET_REDIRECT_URI="http://localhost:9999/callback"
+
+# Discord OAuth config
+GOTRUE_EXTERNAL_DISCORD_ENABLED="false"
+GOTRUE_EXTERNAL_DISCORD_CLIENT_ID=""
+GOTRUE_EXTERNAL_DISCORD_SECRET=""
+GOTRUE_EXTERNAL_DISCORD_REDIRECT_URI="https://localhost:9999/callback"
+
+# Facebook OAuth config
+GOTRUE_EXTERNAL_FACEBOOK_ENABLED="false"
+GOTRUE_EXTERNAL_FACEBOOK_CLIENT_ID=""
+GOTRUE_EXTERNAL_FACEBOOK_SECRET=""
+GOTRUE_EXTERNAL_FACEBOOK_REDIRECT_URI="https://localhost:9999/callback"
+
+# Figma OAuth config
+GOTRUE_EXTERNAL_FIGMA_ENABLED="false"
+GOTRUE_EXTERNAL_FIGMA_CLIENT_ID=""
+GOTRUE_EXTERNAL_FIGMA_SECRET=""
+GOTRUE_EXTERNAL_FIGMA_REDIRECT_URI="https://localhost:9999/callback"
+
+# Gitlab OAuth config
+GOTRUE_EXTERNAL_GITLAB_ENABLED="false"
+GOTRUE_EXTERNAL_GITLAB_CLIENT_ID=""
+GOTRUE_EXTERNAL_GITLAB_SECRET=""
+GOTRUE_EXTERNAL_GITLAB_REDIRECT_URI="http://localhost:9999/callback"
+
+# Google OAuth config
+GOTRUE_EXTERNAL_GOOGLE_ENABLED="false"
+GOTRUE_EXTERNAL_GOOGLE_CLIENT_ID=""
+GOTRUE_EXTERNAL_GOOGLE_SECRET=""
+GOTRUE_EXTERNAL_GOOGLE_REDIRECT_URI="http://localhost:9999/callback"
+
+# Github OAuth config
+GOTRUE_EXTERNAL_GITHUB_ENABLED="false"
+GOTRUE_EXTERNAL_GITHUB_CLIENT_ID=""
+GOTRUE_EXTERNAL_GITHUB_SECRET=""
+GOTRUE_EXTERNAL_GITHUB_REDIRECT_URI="http://localhost:9999/callback"
+
+# Kakao OAuth config
+GOTRUE_EXTERNAL_KAKAO_ENABLED="false"
+GOTRUE_EXTERNAL_KAKAO_CLIENT_ID=""
+GOTRUE_EXTERNAL_KAKAO_SECRET=""
+GOTRUE_EXTERNAL_KAKAO_REDIRECT_URI="http://localhost:9999/callback"
+
+# Notion OAuth config
+GOTRUE_EXTERNAL_NOTION_ENABLED="false"
+GOTRUE_EXTERNAL_NOTION_CLIENT_ID=""
+GOTRUE_EXTERNAL_NOTION_SECRET=""
+GOTRUE_EXTERNAL_NOTION_REDIRECT_URI="https://localhost:9999/callback"
+
+# Twitter OAuth1 config
+GOTRUE_EXTERNAL_TWITTER_ENABLED="false"
+GOTRUE_EXTERNAL_TWITTER_CLIENT_ID=""
+GOTRUE_EXTERNAL_TWITTER_SECRET=""
+GOTRUE_EXTERNAL_TWITTER_REDIRECT_URI="http://localhost:9999/callback"
+
+# Twitch OAuth config
+GOTRUE_EXTERNAL_TWITCH_ENABLED="false"
+GOTRUE_EXTERNAL_TWITCH_CLIENT_ID=""
+GOTRUE_EXTERNAL_TWITCH_SECRET=""
+GOTRUE_EXTERNAL_TWITCH_REDIRECT_URI="http://localhost:9999/callback"
+
+# Spotify OAuth config
+GOTRUE_EXTERNAL_SPOTIFY_ENABLED="false"
+GOTRUE_EXTERNAL_SPOTIFY_CLIENT_ID=""
+GOTRUE_EXTERNAL_SPOTIFY_SECRET=""
+GOTRUE_EXTERNAL_SPOTIFY_REDIRECT_URI="http://localhost:9999/callback"
+
+# Keycloak OAuth config
+GOTRUE_EXTERNAL_KEYCLOAK_ENABLED="false"
+GOTRUE_EXTERNAL_KEYCLOAK_CLIENT_ID=""
+GOTRUE_EXTERNAL_KEYCLOAK_SECRET=""
+GOTRUE_EXTERNAL_KEYCLOAK_REDIRECT_URI="http://localhost:9999/callback"
+GOTRUE_EXTERNAL_KEYCLOAK_URL="https://keycloak.example.com/auth/realms/myrealm"
+
+# Linkedin OAuth config
+GOTRUE_EXTERNAL_LINKEDIN_ENABLED="true"
+GOTRUE_EXTERNAL_LINKEDIN_CLIENT_ID=""
+GOTRUE_EXTERNAL_LINKEDIN_SECRET=""
+
+# Slack OAuth config
+GOTRUE_EXTERNAL_SLACK_ENABLED="false"
+GOTRUE_EXTERNAL_SLACK_CLIENT_ID=""
+GOTRUE_EXTERNAL_SLACK_SECRET=""
+GOTRUE_EXTERNAL_SLACK_REDIRECT_URI="http://localhost:9999/callback"
+
+# WorkOS OAuth config
+GOTRUE_EXTERNAL_WORKOS_ENABLED="true"
+GOTRUE_EXTERNAL_WORKOS_CLIENT_ID=""
+GOTRUE_EXTERNAL_WORKOS_SECRET=""
+GOTRUE_EXTERNAL_WORKOS_REDIRECT_URI="http://localhost:9999/callback"
+
+# Zoom OAuth config
+GOTRUE_EXTERNAL_ZOOM_ENABLED="false"
+GOTRUE_EXTERNAL_ZOOM_CLIENT_ID=""
+GOTRUE_EXTERNAL_ZOOM_SECRET=""
+GOTRUE_EXTERNAL_ZOOM_REDIRECT_URI="http://localhost:9999/callback"
+
+# Anonymous auth config
+GOTRUE_EXTERNAL_ANONYMOUS_USERS_ENABLED="false"
+
+# PKCE Config
+GOTRUE_EXTERNAL_FLOW_STATE_EXPIRY_DURATION="300s"
+
+# Phone provider config
+GOTRUE_SMS_AUTOCONFIRM="false"
+GOTRUE_SMS_MAX_FREQUENCY="5s"
+GOTRUE_SMS_OTP_EXP="6000"
+GOTRUE_SMS_OTP_LENGTH="6"
+GOTRUE_SMS_PROVIDER="twilio"
+GOTRUE_SMS_TWILIO_ACCOUNT_SID=""
+GOTRUE_SMS_TWILIO_AUTH_TOKEN=""
+GOTRUE_SMS_TWILIO_MESSAGE_SERVICE_SID=""
+GOTRUE_SMS_TEMPLATE="This is from supabase. Your code is {{ .Code }} ."
+GOTRUE_SMS_MESSAGEBIRD_ACCESS_KEY=""
+GOTRUE_SMS_MESSAGEBIRD_ORIGINATOR=""
+GOTRUE_SMS_TEXTLOCAL_API_KEY=""
+GOTRUE_SMS_TEXTLOCAL_SENDER=""
+GOTRUE_SMS_VONAGE_API_KEY=""
+GOTRUE_SMS_VONAGE_API_SECRET=""
+GOTRUE_SMS_VONAGE_FROM=""
+
+# Captcha config
+GOTRUE_SECURITY_CAPTCHA_ENABLED="false"
+GOTRUE_SECURITY_CAPTCHA_PROVIDER="hcaptcha"
+GOTRUE_SECURITY_CAPTCHA_SECRET="0x0000000000000000000000000000000000000000"
+GOTRUE_SECURITY_CAPTCHA_TIMEOUT="10s"
+GOTRUE_SESSION_KEY=""
+
+# SAML config
+GOTRUE_EXTERNAL_SAML_ENABLED="true"
+GOTRUE_EXTERNAL_SAML_METADATA_URL=""
+GOTRUE_EXTERNAL_SAML_API_BASE="http://localhost:9999"
+GOTRUE_EXTERNAL_SAML_NAME="auth0"
+GOTRUE_EXTERNAL_SAML_SIGNING_CERT=""
+GOTRUE_EXTERNAL_SAML_SIGNING_KEY=""
+
+# Additional Security config
+GOTRUE_LOG_LEVEL="debug"
+GOTRUE_SECURITY_REFRESH_TOKEN_ROTATION_ENABLED="false"
+GOTRUE_SECURITY_REFRESH_TOKEN_REUSE_INTERVAL="0"
+GOTRUE_SECURITY_UPDATE_PASSWORD_REQUIRE_REAUTHENTICATION="false"
+GOTRUE_OPERATOR_TOKEN="unused-operator-token"
+GOTRUE_RATE_LIMIT_HEADER="X-Forwarded-For"
+GOTRUE_RATE_LIMIT_EMAIL_SENT="100"
+
+GOTRUE_MAX_VERIFIED_FACTORS=10
+
+# Auth Hook Configuration
+GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_ENABLED=false
+GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_URI=""
+# Only for HTTPS Hooks
+GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_SECRET=""
+
+GOTRUE_HOOK_CUSTOM_SMS_PROVIDER_ENABLED=false
+GOTRUE_HOOK_CUSTOM_SMS_PROVIDER_URI=""
+# Only for HTTPS Hooks
+GOTRUE_HOOK_CUSTOM_SMS_PROVIDER_SECRET=""
+
+
+# Test OTP Config
+GOTRUE_SMS_TEST_OTP=":, :..."
+GOTRUE_SMS_TEST_OTP_VALID_UNTIL="" # (e.g.
2023-09-29T08:14:06Z) + +GOTRUE_MFA_WEB_AUTHN_ENROLL_ENABLED="false" +GOTRUE_MFA_WEB_AUTHN_VERIFY_ENABLED="false" diff --git a/auth_v2.169.0/go.mod b/auth_v2.169.0/go.mod new file mode 100644 index 0000000..a99b2b6 --- /dev/null +++ b/auth_v2.169.0/go.mod @@ -0,0 +1,163 @@ +module github.com/supabase/auth + +require ( + github.com/Masterminds/semver/v3 v3.1.1 // indirect + github.com/aaronarduino/goqrsvg v0.0.0-20220419053939-17e843f1dd40 + github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b + github.com/badoux/checkmail v0.0.0-20170203135005-d0a759655d62 + github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc + github.com/coreos/go-oidc/v3 v3.6.0 + github.com/didip/tollbooth/v5 v5.1.1 + github.com/gobuffalo/validate/v3 v3.3.3 // indirect + github.com/gobwas/glob v0.2.3 + github.com/gofrs/uuid v4.3.1+incompatible + github.com/jackc/pgconn v1.14.3 + github.com/jackc/pgerrcode v0.0.0-20201024163028-a0d42d470451 + github.com/jackc/pgproto3/v2 v2.3.3 // indirect + github.com/jmoiron/sqlx v1.3.5 + github.com/joho/godotenv v1.4.0 + github.com/kelseyhightower/envconfig v1.4.0 + github.com/microcosm-cc/bluemonday v1.0.26 // indirect + github.com/mitchellh/mapstructure v1.5.0 + github.com/mrjones/oauth v0.0.0-20190623134757-126b35219450 + github.com/pkg/errors v0.9.1 + github.com/pquerna/otp v1.4.0 + github.com/rs/cors v1.11.0 + github.com/sebest/xff v0.0.0-20160910043805-6c115e0ffa35 + github.com/sethvargo/go-password v0.2.0 + github.com/sirupsen/logrus v1.9.3 + github.com/spf13/cobra v1.6.1 + github.com/stretchr/testify v1.9.0 + golang.org/x/crypto v0.31.0 + golang.org/x/oauth2 v0.17.0 + gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df +) + +require ( + github.com/bits-and-blooms/bitset v1.10.0 // indirect + github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 // indirect + github.com/fxamacker/cbor/v2 v2.7.0 // indirect + github.com/go-jose/go-jose/v3 v3.0.3 // indirect + github.com/go-webauthn/x v0.1.12 // indirect + github.com/gobuffalo/nulls v0.4.2 // indirect + github.com/goccy/go-json v0.10.3 // indirect + github.com/google/go-tpm v0.9.1 // indirect + github.com/jackc/pgx/v4 v4.18.2 // indirect + github.com/lestrrat-go/blackmagic v1.0.2 // indirect + github.com/lestrrat-go/httpcc v1.0.1 // indirect + github.com/lestrrat-go/httprc v1.0.5 // indirect + github.com/lestrrat-go/iter v1.0.2 // indirect + github.com/lestrrat-go/option v1.0.1 // indirect + github.com/segmentio/asm v1.2.0 // indirect + github.com/x448/float16 v0.8.4 // indirect + github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f // indirect + github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect + golang.org/x/mod v0.17.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda // indirect +) + +require ( + github.com/XSAM/otelsql v0.26.0 + github.com/bombsimon/logrusr/v3 v3.0.0 + go.opentelemetry.io/contrib/instrumentation/runtime v0.45.0 + go.opentelemetry.io/otel v1.26.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.19.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 + go.opentelemetry.io/otel/metric v1.26.0 + go.opentelemetry.io/otel/sdk v1.26.0 + go.opentelemetry.io/otel/sdk/metric v1.26.0 + go.opentelemetry.io/otel/trace v1.26.0 + gopkg.in/h2non/gock.v1 v1.1.2 +) + +require ( + github.com/bits-and-blooms/bloom/v3 v3.6.0 + 
github.com/crewjam/saml v0.4.14 + github.com/deepmap/oapi-codegen v1.12.4 + github.com/fatih/structs v1.1.0 + github.com/fsnotify/fsnotify v1.7.0 + github.com/go-chi/chi/v5 v5.0.12 + github.com/go-webauthn/webauthn v0.11.1 + github.com/gobuffalo/pop/v6 v6.1.1 + github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/lestrrat-go/jwx/v2 v2.1.0 + github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721 + github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869 + github.com/xeipuuv/gojsonschema v1.2.0 + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 + go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.26.0 + go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.26.0 + go.opentelemetry.io/otel/exporters/prometheus v0.48.0 +) + +require ( + github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect + github.com/aymerick/douceur v0.2.0 // indirect + github.com/beevik/etree v1.1.0 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/crewjam/httperr v0.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fatih/color v1.13.0 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-logr/logr v1.4.1 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/gobuffalo/envy v1.10.2 // indirect + github.com/gobuffalo/fizz v1.14.4 // indirect + github.com/gobuffalo/flect v1.0.2 // indirect + github.com/gobuffalo/github_flavored_markdown v1.1.3 // indirect + github.com/gobuffalo/helpers v0.6.7 // indirect + github.com/gobuffalo/plush/v4 v4.1.18 // indirect + github.com/gobuffalo/tags/v3 v3.1.4 // indirect + github.com/golang-jwt/jwt/v4 v4.5.1 // indirect + github.com/golang/protobuf v1.5.4 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/gorilla/css v1.0.0 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 // indirect + github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect + github.com/inconshreveable/mousetrap v1.0.1 // indirect + github.com/jackc/chunkreader/v2 v2.0.1 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgtype v1.14.0 // indirect + github.com/jonboulle/clockwork v0.2.2 // indirect + github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect + github.com/luna-duclos/instrumentedsql v1.1.3 // indirect + github.com/mattermost/xml-roundtrip-validator v0.1.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.16 // indirect + github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect + github.com/patrickmn/go-cache v2.1.0+incompatible // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_golang v1.19.0 + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.48.0 // indirect + github.com/prometheus/procfs v0.12.0 // indirect + github.com/rogpeppe/go-internal v1.11.0 // indirect + github.com/russellhaering/goxmldsig v1.3.0 // indirect + github.com/sergi/go-diff v1.2.0 // indirect + github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d // indirect + github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e // indirect + 
github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/objx v0.5.2 // indirect + go.opentelemetry.io/proto/otlp v1.2.0 // indirect + golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb + golang.org/x/net v0.23.0 // indirect + golang.org/x/sync v0.10.0 + golang.org/x/sys v0.28.0 // indirect + golang.org/x/text v0.21.0 // indirect + golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect + google.golang.org/appengine v1.6.8 // indirect + google.golang.org/grpc v1.63.2 // indirect + google.golang.org/protobuf v1.33.0 // indirect + gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +go 1.22.3 diff --git a/auth_v2.169.0/go.sum b/auth_v2.169.0/go.sum new file mode 100644 index 0000000..827144a --- /dev/null +++ b/auth_v2.169.0/go.sum @@ -0,0 +1,559 @@ +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= +github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= +github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk= +github.com/XSAM/otelsql v0.26.0 h1:UhAGVBD34Ctbh2aYcm/JAdL+6T6ybrP+YMWYkHqCdmo= +github.com/XSAM/otelsql v0.26.0/go.mod h1:5ciw61eMSh+RtTPN8spvPEPLJpAErZw8mFFPNfYiaxA= +github.com/aaronarduino/goqrsvg v0.0.0-20220419053939-17e843f1dd40 h1:uz4N2yHL4MF8vZX+36n+tcxeUf8D/gL4aJkyouhDw4A= +github.com/aaronarduino/goqrsvg v0.0.0-20220419053939-17e843f1dd40/go.mod h1:dytw+5qs+pdi61fO/S4OmXR7AuEq/HvNCuG03KxQHT4= +github.com/ajstarks/deck v0.0.0-20200831202436-30c9fc6549a9/go.mod h1:JynElWSGnm/4RlzPXRlREEwqTHAN3T56Bv2ITsFT3gY= +github.com/ajstarks/deck/generate v0.0.0-20210309230005-c3f852c02e19/go.mod h1:T13YZdzov6OU0A1+RfKZiZN9ca6VeKdBdyDV+BY97Tk= +github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b h1:slYM766cy2nI3BwyRiyQj/Ud48djTMtMebDqepE95rw= +github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM= +github.com/apapsch/go-jsonmerge/v2 v2.0.0 h1:axGnT1gRIfimI7gJifB699GoE/oq+F2MU7Dml6nw9rQ= +github.com/apapsch/go-jsonmerge/v2 v2.0.0/go.mod h1:lvDnEdqiQrp0O42VQGgmlKpxL1AP2+08jFMw88y4klk= +github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= +github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= +github.com/badoux/checkmail v0.0.0-20170203135005-d0a759655d62 h1:vMqcPzLT1/mbYew0gM6EJy4/sCNy9lY9rmlFO+pPwhY= +github.com/badoux/checkmail v0.0.0-20170203135005-d0a759655d62/go.mod h1:r5ZalvRl3tXevRNJkwIB6DC4DD3DMjIlY9NEU1XGoaQ= +github.com/beevik/etree v1.1.0 h1:T0xke/WvNtMoCqgzPhkX2r4rjY3GDZFi+FjpRZY2Jbs= +github.com/beevik/etree v1.1.0/go.mod h1:r8Aw8JqVegEf0w2fDnATrX9VpkMcyFeM0FhwO62wh+A= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bits-and-blooms/bitset v1.10.0 h1:ePXTeiPEazB5+opbv5fr8umg2R/1NlzgDsyepwsSr88= +github.com/bits-and-blooms/bitset v1.10.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/bits-and-blooms/bloom/v3 v3.6.0 h1:dTU0OVLJSoOhz9m68FTXMFfA39nR8U/nTCs1zb26mOI= +github.com/bits-and-blooms/bloom/v3 v3.6.0/go.mod h1:VKlUSvp0lFIYqxJjzdnSsZEw4iHb1kOL2tfHTgyJBHg= +github.com/bmatcuk/doublestar v1.1.1/go.mod h1:UD6OnuiIn0yFxxA2le/rnRU1G4RaI4UvFv1sNto9p6w= 
+github.com/bombsimon/logrusr/v3 v3.0.0 h1:tcAoLfuAhKP9npBxWzSdpsvKPQt1XV02nSf2lZA82TQ= +github.com/bombsimon/logrusr/v3 v3.0.0/go.mod h1:PksPPgSFEL2I52pla2glgCyyd2OqOHAnFF5E+g8Ixco= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= +github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-oidc/v3 v3.6.0 h1:AKVxfYw1Gmkn/w96z0DbT/B/xFnzTd3MkZvWLjF4n/o= +github.com/coreos/go-oidc/v3 v3.6.0/go.mod h1:ZpHUsHBucTUj6WOkrP4E20UPynbLZzhTQ1XKCXkxyPc= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/crewjam/httperr v0.2.0 h1:b2BfXR8U3AlIHwNeFFvZ+BV1LFvKLlzMjzaTnZMybNo= +github.com/crewjam/httperr v0.2.0/go.mod h1:Jlz+Sg/XqBQhyMjdDiC+GNNRzZTD7x39Gu3pglZ5oH4= +github.com/crewjam/saml v0.4.14 h1:g9FBNx62osKusnFzs3QTN5L9CVA/Egfgm+stJShzw/c= +github.com/crewjam/saml v0.4.14/go.mod h1:UVSZCf18jJkk6GpWNVqcyQJMD5HsRugBPf4I1nl2mME= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 h1:rpfIENRNNilwHwZeG5+P150SMrnNEcHYvcCuK6dPZSg= +github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= +github.com/deepmap/oapi-codegen v1.12.4 h1:pPmn6qI9MuOtCz82WY2Xaw46EQjgvxednXXrP7g5Q2s= +github.com/deepmap/oapi-codegen v1.12.4/go.mod h1:3lgHGMu6myQ2vqbbTXH2H1o4eXFTGnFiDaOaKKl5yas= +github.com/didip/tollbooth/v5 v5.1.1 h1:QpKFg56jsbNuQ6FFj++Z1gn2fbBsvAc1ZPLUaDOYW5k= +github.com/didip/tollbooth/v5 v5.1.1/go.mod h1:d9rzwOULswrD3YIrAQmP3bfjxab32Df4IaO6+D25l9g= +github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= +github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= +github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= +github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/fxamacker/cbor/v2 v2.7.0 
h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s= +github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-jose/go-jose/v3 v3.0.3 h1:fFKWeig/irsp7XD2zBxvnmA/XaRWp5V3CBsZXJF7G7k= +github.com/go-jose/go-jose/v3 v3.0.3/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= +github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.1 h1:pKouT5E8xu9zeFC39JXRDukb6JFQPXM5p5I91188VAQ= +github.com/go-logr/logr v1.4.1/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/go-webauthn/webauthn v0.11.1 h1:5G/+dg91/VcaJHTtJUfwIlNJkLwbJCcnUc4W8VtkpzA= +github.com/go-webauthn/webauthn v0.11.1/go.mod h1:YXRm1WG0OtUyDFaVAgB5KG7kVqW+6dYCJ7FTQH4SxEE= +github.com/go-webauthn/x v0.1.12 h1:RjQ5cvApzyU/xLCiP+rub0PE4HBZsLggbxGR5ZpUf/A= +github.com/go-webauthn/x v0.1.12/go.mod h1:XlRcGkNH8PT45TfeJYc6gqpOtiOendHhVmnOxh+5yHs= +github.com/gobuffalo/attrs v1.0.3/go.mod h1:KvDJCE0avbufqS0Bw3UV7RQynESY0jjod+572ctX4t8= +github.com/gobuffalo/envy v1.10.2 h1:EIi03p9c3yeuRCFPOKcSfajzkLb3hrRjEpHGI8I2Wo4= +github.com/gobuffalo/envy v1.10.2/go.mod h1:qGAGwdvDsaEtPhfBzb3o0SfDea8ByGn9j8bKmVft9z8= +github.com/gobuffalo/fizz v1.14.4 h1:8uume7joF6niTNWN582IQ2jhGTUoa9g1fiV/tIoGdBs= +github.com/gobuffalo/fizz v1.14.4/go.mod h1:9/2fGNXNeIFOXEEgTPJwiK63e44RjG+Nc4hfMm1ArGM= +github.com/gobuffalo/flect v0.3.0/go.mod h1:5pf3aGnsvqvCj50AVni7mJJF8ICxGZ8HomberC3pXLE= +github.com/gobuffalo/flect v1.0.0/go.mod h1:l9V6xSb4BlXwsxEMj3FVEub2nkdQjWhPvD8XTTlHPQc= +github.com/gobuffalo/flect v1.0.2 h1:eqjPGSo2WmjgY2XlpGwo2NXgL3RucAKo4k4qQMNA5sA= +github.com/gobuffalo/flect v1.0.2/go.mod h1:A5msMlrHtLqh9umBSnvabjsMrCcCpAyzglnDvkbYKHs= +github.com/gobuffalo/genny/v2 v2.1.0/go.mod h1:4yoTNk4bYuP3BMM6uQKYPvtP6WsXFGm2w2EFYZdRls8= +github.com/gobuffalo/github_flavored_markdown v1.1.3 h1:rSMPtx9ePkFB22vJ+dH+m/EUBS8doQ3S8LeEXcdwZHk= +github.com/gobuffalo/github_flavored_markdown v1.1.3/go.mod h1:IzgO5xS6hqkDmUh91BW/+Qxo/qYnvfzoz3A7uLkg77I= +github.com/gobuffalo/helpers v0.6.7 h1:C9CedoRSfgWg2ZoIkVXgjI5kgmSpL34Z3qdnzpfNVd8= +github.com/gobuffalo/helpers v0.6.7/go.mod h1:j0u1iC1VqlCaJEEVkZN8Ia3TEzfj/zoXANqyJExTMTA= +github.com/gobuffalo/logger v1.0.7/go.mod h1:u40u6Bq3VVvaMcy5sRBclD8SXhBYPS0Qk95ubt+1xJM= +github.com/gobuffalo/nulls v0.4.2 h1:GAqBR29R3oPY+WCC7JL9KKk9erchaNuV6unsOSZGQkw= +github.com/gobuffalo/nulls v0.4.2/go.mod h1:EElw2zmBYafU2R9W4Ii1ByIj177wA/pc0JdjtD0EsH8= +github.com/gobuffalo/packd v1.0.2/go.mod h1:sUc61tDqGMXON80zpKGp92lDb86Km28jfvX7IAyxFT8= +github.com/gobuffalo/plush/v4 v4.1.16/go.mod h1:6t7swVsarJ8qSLw1qyAH/KbrcSTwdun2ASEQkOznakg= 
+github.com/gobuffalo/plush/v4 v4.1.18 h1:bnPjdMTEUQHqj9TNX2Ck3mxEXYZa+0nrFMNM07kpX9g= +github.com/gobuffalo/plush/v4 v4.1.18/go.mod h1:xi2tJIhFI4UdzIL8sxZtzGYOd2xbBpcFbLZlIPGGZhU= +github.com/gobuffalo/pop/v6 v6.1.1 h1:eUDBaZcb0gYrmFnKwpuTEUA7t5ZHqNfvS4POqJYXDZY= +github.com/gobuffalo/pop/v6 v6.1.1/go.mod h1:1n7jAmI1i7fxuXPZjZb0VBPQDbksRtCoFnrDV5IsvaI= +github.com/gobuffalo/tags/v3 v3.1.4 h1:X/ydLLPhgXV4h04Hp2xlbI2oc5MDaa7eub6zw8oHjsM= +github.com/gobuffalo/tags/v3 v3.1.4/go.mod h1:ArRNo3ErlHO8BtdA0REaZxijuWnWzF6PUXngmMXd2I0= +github.com/gobuffalo/validate/v3 v3.3.3 h1:o7wkIGSvZBYBd6ChQoLxkz2y1pfmhbI4jNJYh6PuNJ4= +github.com/gobuffalo/validate/v3 v3.3.3/go.mod h1:YC7FsbJ/9hW/VjQdmXPvFqvRis4vrRYFxr69WiNZw6g= +github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= +github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= +github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/gofrs/uuid v4.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/gofrs/uuid v4.3.1+incompatible h1:0/KbAdpx3UXAx1kEOWHJeOkpbgRFGHVgv+CFIY7dBJI= +github.com/gofrs/uuid v4.3.1+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo= +github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= +github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-tpm v0.9.1 h1:0pGc4X//bAlmZzMKf8iz6IsDo1nYTbYJ6FZN/rg4zdM= +github.com/google/go-tpm v0.9.1/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= +github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 h1:/c3QmbOGMGTOumP2iT/rCwB7b0QDGLKzqOmktBjT+Is= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1/go.mod h1:5SN9VR2LTsRFsrEC6FHgRbTWrTHu6tqPeKxEQv15giM= +github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= +github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= +github.com/inconshreveable/mousetrap v1.0.1 
h1:U3uMjPSQEBMNp1lFxmllqCPM6P5u/Xq7Pgzkat/bFNc= +github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= +github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgconn v1.13.0/go.mod h1:AnowpAqO4CMIIJNZl2VJp+KrkAZciAkhEl0W0JIobpI= +github.com/jackc/pgconn v1.14.3 h1:bVoTr12EGANZz66nZPkMInAV/KHD2TxH9npjXXgiB3w= +github.com/jackc/pgconn v1.14.3/go.mod h1:RZbme4uasqzybK2RK5c65VsHxoyaml09lx3tXOcO/VM= +github.com/jackc/pgerrcode v0.0.0-20201024163028-a0d42d470451 h1:WAvSpGf7MsFuzAtK4Vk7R4EVe+liW4x83r4oWu0WHKw= +github.com/jackc/pgerrcode v0.0.0-20201024163028-a0d42d470451/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.3.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.3.3 h1:1HLSx5H+tXR9pW3in3zaztoEwQYRC9SQaYUHjTSUOag= +github.com/jackc/pgproto3/v2 v2.3.3/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod 
h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= +github.com/jackc/pgtype v1.12.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw= +github.com/jackc/pgtype v1.14.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= +github.com/jackc/pgx/v4 v4.17.2/go.mod h1:lcxIZN44yMIrWI78a5CpucdD14hX0SBDbNRvjDBItsw= +github.com/jackc/pgx/v4 v4.18.2 h1:xVpYkNR5pk5bMCZGfClbO962UIqVABcAGt7ha1s/FeU= +github.com/jackc/pgx/v4 v4.18.2/go.mod h1:Ey4Oru5tH5sB6tV7hDmfWFahwF15Eb7DNXlRKx2CkVw= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= +github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= +github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg= +github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= +github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= +github.com/juju/gnuflag v0.0.0-20171113085948-2ce1bb71843d/go.mod h1:2PavIy+JPciBPrBUjwbNvtwB6RQlve+hkpll6QSNmOE= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8= +github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty 
v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= +github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU= +github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE= +github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E= +github.com/lestrrat-go/httprc v1.0.5 h1:bsTfiH8xaKOJPrg1R+E3iE/AWZr/x0Phj9PBTG/OLUk= +github.com/lestrrat-go/httprc v1.0.5/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= +github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= +github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= +github.com/lestrrat-go/jwx/v2 v2.1.0 h1:0zs7Ya6+39qoit7gwAf+cYm1zzgS3fceIdo7RmQ5lkw= +github.com/lestrrat-go/jwx/v2 v2.1.0/go.mod h1:Xpw9QIaUGiIUD1Wx0NcY1sIHwFf8lDuZn/cmxtXYRys= +github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU= +github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= +github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/luna-duclos/instrumentedsql v1.1.3 h1:t7mvC0z1jUt5A0UQ6I/0H31ryymuQRnJcWCiqV3lSAA= +github.com/luna-duclos/instrumentedsql v1.1.3/go.mod h1:9J1njvFds+zN7y85EDhN9XNQLANWwZt2ULeIC8yMNYs= +github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= +github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14/go.mod 
h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= +github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/microcosm-cc/bluemonday v1.0.20/go.mod h1:yfBmMi8mxvaZut3Yytv+jTXRY8mxyjJ0/kQBTElld50= +github.com/microcosm-cc/bluemonday v1.0.26 h1:xbqSvqzQMeEHCqMi64VAs4d8uy6Mequs3rQ0k/Khz58= +github.com/microcosm-cc/bluemonday v1.0.26/go.mod h1:JyzOCs9gkyQyjs+6h10UEVSe02CGwkhd72Xdqh78TWs= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mrjones/oauth v0.0.0-20190623134757-126b35219450 h1:j2kD3MT1z4PXCiUllUJF9mWUESr9TWKS7iEKsQ/IipM= +github.com/mrjones/oauth v0.0.0-20190623134757-126b35219450/go.mod h1:skjdDftzkFALcuGzYSklqYd8gvat6F1gZJ4YPVbkZpM= +github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32 h1:W6apQkHrMkS0Muv8G/TipAy/FJl/rCYT0+EuS8+Z0z4= +github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uYLpLIr5fm8diHn0JbqRycJi6w0Ms= +github.com/patrickmn/go-cache v0.0.0-20170418232947-7ac151875ffb/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pquerna/otp v1.4.0 h1:wZvl1TIVxKRThZIBiwOOHOGP/1+nZyWBil9Y2XNEDzg= +github.com/pquerna/otp v1.4.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg= +github.com/prometheus/client_golang v1.19.0 h1:ygXvpU1AoN1MhdzckN+PyD9QJOSD4x7kmXYlnfbA6JU= +github.com/prometheus/client_golang v1.19.0/go.mod h1:ZRM9uEAypZakd+q/x7+gmsvXdURP+DABIEIjnmDdp+k= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE= +github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= +github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= +github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.6.1/go.mod 
h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rs/cors v1.11.0 h1:0B9GE/r9Bc2UxRMMtymBkHTenPkHDv0CW4Y98GBY+po= +github.com/rs/cors v1.11.0/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/russellhaering/goxmldsig v1.3.0 h1:DllIWUgMy0cRUMfGiASiYEa35nsieyD3cigIwLonTPM= +github.com/russellhaering/goxmldsig v1.3.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/sebest/xff v0.0.0-20160910043805-6c115e0ffa35 h1:eajwn6K3weW5cd1ZXLu2sJ4pvwlBiCWY4uDejOr73gM= +github.com/sebest/xff v0.0.0-20160910043805-6c115e0ffa35/go.mod h1:wozgYq9WEBQBaIJe4YZ0qTSFAMxmcwBhQH0fO0R34Z0= +github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= +github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= +github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= +github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= +github.com/sethvargo/go-password v0.2.0 h1:BTDl4CC/gjf/axHMaDQtw507ogrXLci6XRiLc7i/UHI= +github.com/sethvargo/go-password v0.2.0/go.mod h1:Ym4Mr9JXLBycr02MFuVQ/0JHidNetSgbzutTr3zsYXE= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d h1:yKm7XZV6j9Ev6lojP2XaIshpT4ymkqhMeSghO5Ps00E= +github.com/sourcegraph/annotate v0.0.0-20160123013949-f4cad6c6324d/go.mod h1:UdhH50NIW0fCiwBSr0co2m7BnFLdv4fQTgdqdJTHFeE= +github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e h1:qpG93cPwA5f7s/ZPBJnGOYQNK/vKsaDaseuKT5Asee8= +github.com/sourcegraph/syntaxhighlight v0.0.0-20170531221838-bd320f5d308e/go.mod h1:HuIsMU8RRBOtsCgI77wP899iHVBQpCmg4ErYMZB+2IA= +github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA= +github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spkg/bom 
v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0= +github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721 h1:HTsFo0buahHfjuVUTPDdJRBkfjExkRM1LUBy6crQ7lc= +github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721/go.mod h1:L1MQhA6x4dn9r007T033lsaZMv9EmBAdXyU/+EF40fo= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869 h1:VDuRtwen5Z7QQ5ctuHUse4wAv/JozkKZkdic5vUV4Lg= +github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869/go.mod h1:eHX5nlSMSnyPjUrbYzeqrA8snCe2SKyfizKjU3dkfOw= +github.com/twmb/murmur3 v1.1.6 h1:mqrRot1BRxm+Yct+vavLMou2/iJt0tNVTTC0QoIjaZg= +github.com/twmb/murmur3 v1.1.6/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f h1:J9EGpcZtP0E/raorCMxlFGSTBrsSlaDGf3jU/qvAE2c= +github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= +github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= +github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= +github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 
h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc= +go.opentelemetry.io/contrib/instrumentation/runtime v0.45.0 h1:2JydY5UiDpqvj2p7sO9bgHuhTy4hgTZ0ymehdq/Ob0Q= +go.opentelemetry.io/contrib/instrumentation/runtime v0.45.0/go.mod h1:ch3a5QxOqVWxas4CzjCFFOOQe+7HgAXC/N1oVxS9DK4= +go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs= +go.opentelemetry.io/otel v1.26.0/go.mod h1:UmLkJHUAidDval2EICqBMbnAd0/m2vmpf/dAM+fvFs4= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.26.0 h1:+hm+I+KigBy3M24/h1p/NHkUx/evbLH0PNcjpMyCHc4= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.26.0/go.mod h1:NjC8142mLvvNT6biDpaMjyz78kyEHIwAJlSX0N9P5KI= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.26.0 h1:HGZWGmCVRCVyAs2GQaiHQPbDHo+ObFWeUEOd+zDnp64= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.26.0/go.mod h1:SaH+v38LSCHddyk7RGlU9uZyQoRrKao6IBnJw6Kbn+c= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0/go.mod h1:IPtUMKL4O3tH5y+iXVyAXqpAwMuzC1IrxVS81rummfE= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.19.0 h1:3d+S281UTjM+AbF31XSOYn1qXn3BgIdWl8HNEpx08Jk= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.19.0/go.mod h1:0+KuTDyKL4gjKCF75pHOX4wuzYDUZYfAQdSu43o+Z2I= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU= +go.opentelemetry.io/otel/exporters/prometheus v0.48.0 h1:sBQe3VNGUjY9IKWQC6z2lNqa5iGbDSxhs60ABwK4y0s= +go.opentelemetry.io/otel/exporters/prometheus v0.48.0/go.mod h1:DtrbMzoZWwQHyrQmCfLam5DZbnmorsGbOtTbYHycU5o= +go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgSllRz30= +go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4= +go.opentelemetry.io/otel/sdk v1.26.0 h1:Y7bumHf5tAiDlRYFmGqetNcLaVUZmh4iYfmGxtmz7F8= +go.opentelemetry.io/otel/sdk v1.26.0/go.mod h1:0p8MXpqLeJ0pzcszQQN4F0S5FVjBLgypeGSngLsmirs= +go.opentelemetry.io/otel/sdk/metric v1.26.0 h1:cWSks5tfriHPdWFnl+qpX3P681aAYqlZHcAyHw5aU9Y= +go.opentelemetry.io/otel/sdk/metric v1.26.0/go.mod h1:ClMFFknnThJCksebJwz7KIyEDHO+nTB6gK8obLy8RyE= +go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA= +go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0= +go.opentelemetry.io/proto/otlp v1.2.0 h1:pVeZGk7nXDC9O2hncA6nHldxEjm6LByfA2aN8IOkz94= +go.opentelemetry.io/proto/otlp v1.2.0/go.mod h1:gGpR8txAl5M03pDhMC79G6SdqNV26naRm/KDsgaHD8A= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= +go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.3.0/go.mod 
h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb h1:PaBZQdo+iSDyHT053FjUCgZQ/9uqVwPOcl7KSWhKn6w= +golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.0.0-20161007143504-f4b625ec9b21/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod 
h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.0.0-20221002022538-bcab6841153b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= +golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/oauth2 v0.17.0 h1:6m3ZPmLEFdVxKKWnKq4VqZ60gutO35zm+zrAHVmHyDQ= +golang.org/x/oauth2 v0.17.0/go.mod h1:OzPDGQiuQMguemayvdylqddI7qcD9lnSDb+1FiwQ5HA= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.0.0-20220722155259-a9ba230a4035/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/time v0.0.0-20160926182426-711ca1cb8763/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20220411224347-583f2d630306 h1:+gHMid33q6pen7kv9xvT+JRinntgeXO2AeZVd0AWD3w= +golang.org/x/time v0.0.0-20220411224347-583f2d630306/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools 
v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= +google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= +google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de h1:F6qOa9AZTYJXOUEr4jDysRDLrm4PHePlge4v4TGAlxY= +google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:VUhTRKeHn9wwcdrk73nvdC9gF178Tzhmt/qyaFcPLSo= +google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de h1:jFNzHPIeuzhdRwVhbZdiym9q0ory/xY3sA+v2wPg8I0= +google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:5iCWqnniDlqZHrd3neWVTOwvh/v6s3232omMecelax8= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda h1:LI5DOvAxUPMv/50agcLLoo+AdWc1irS9Rzz4vPuD1V4= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY= +google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM= +google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= +gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= 
+gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE= +gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw= +gopkg.in/h2non/gock.v1 v1.1.2 h1:jBbHXgGBK/AoPVfJh5x4r/WxIrElvbLel8TCZkkZJoY= +gopkg.in/h2non/gock.v1 v1.1.2/go.mod h1:n7UGz/ckNChHiK05rDoiC4MYSunEC/lyaUm2WWaDva0= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= +gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +honnef.co/go/tools v0.1.3/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= diff --git a/auth_v2.169.0/hack/coverage.sh b/auth_v2.169.0/hack/coverage.sh new file mode 100644 index 0000000..5510196 --- /dev/null +++ b/auth_v2.169.0/hack/coverage.sh @@ -0,0 +1,21 @@ +FAIL=false + +for PKG in "crypto" +do + UNCOVERED_FUNCS=$(go tool cover -func=coverage.out | grep "^github.com/supabase/auth/internal/$PKG/" | grep -v '100.0%$') + UNCOVERED_FUNCS_COUNT=$(echo "$UNCOVERED_FUNCS" | wc -l) + + if [ "$UNCOVERED_FUNCS_COUNT" -gt 1 ] # wc -l counts +1 line + then + echo "Package $PKG not covered 100% with tests. $UNCOVERED_FUNCS_COUNT functions need more tests. This is mandatory." 
+ echo "$UNCOVERED_FUNCS" + FAIL=true + fi +done + +if [ "$FAIL" = "true" ] +then + exit 1 +else + exit 0 +fi diff --git a/auth_v2.169.0/hack/database.yml b/auth_v2.169.0/hack/database.yml new file mode 100644 index 0000000..8614ce4 --- /dev/null +++ b/auth_v2.169.0/hack/database.yml @@ -0,0 +1,15 @@ +postgres: + dialect: "postgres" + database: "postgres" + host: {{ envOr "POSTGRES_HOST" "127.0.0.1" }} + port: {{ envOr "POSTGRES_PORT" "5432" }} + user: {{ envOr "POSTGRES_USER" "postgres" }} + password: {{ envOr "POSTGRES_PASSWORD" "root" }} + +test: + dialect: "postgres" + database: "postgres" + host: {{ envOr "POSTGRES_HOST" "127.0.0.1" }} + port: {{ envOr "POSTGRES_PORT" "5432" }} + user: {{ envOr "POSTGRES_USER" "postgres" }} + password: {{ envOr "POSTGRES_PASSWORD" "root" }} diff --git a/auth_v2.169.0/hack/init_postgres.sql b/auth_v2.169.0/hack/init_postgres.sql new file mode 100644 index 0000000..d1ef709 --- /dev/null +++ b/auth_v2.169.0/hack/init_postgres.sql @@ -0,0 +1,7 @@ +CREATE USER supabase_admin LOGIN CREATEROLE CREATEDB REPLICATION BYPASSRLS; + +-- Supabase super admin +CREATE USER supabase_auth_admin NOINHERIT CREATEROLE LOGIN NOREPLICATION PASSWORD 'root'; +CREATE SCHEMA IF NOT EXISTS auth AUTHORIZATION supabase_auth_admin; +GRANT CREATE ON DATABASE postgres TO supabase_auth_admin; +ALTER USER supabase_auth_admin SET search_path = 'auth'; diff --git a/auth_v2.169.0/hack/migrate.sh b/auth_v2.169.0/hack/migrate.sh new file mode 100644 index 0000000..2d1f0e5 --- /dev/null +++ b/auth_v2.169.0/hack/migrate.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +DB_ENV=$1 + +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +DATABASE="$DIR/database.yml" + +export GOTRUE_DB_DRIVER="postgres" +export GOTRUE_DB_DATABASE_URL="postgres://supabase_auth_admin:root@localhost:5432/$DB_ENV" +export GOTRUE_DB_MIGRATIONS_PATH=$DIR/../migrations + +go run main.go migrate -c $DIR/test.env diff --git a/auth_v2.169.0/hack/postgresd.sh b/auth_v2.169.0/hack/postgresd.sh new file mode 100644 index 0000000..c4b6a58 --- /dev/null +++ b/auth_v2.169.0/hack/postgresd.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash + +docker rm -f gotrue_postgresql >/dev/null 2>/dev/null || true + +docker volume inspect postgres_data 2>/dev/null >/dev/null || docker volume create --name postgres_data >/dev/null + +docker run --name gotrue_postgresql \ + -p 5432:5432 \ + -e POSTGRES_USER=postgres \ + -e POSTGRES_PASSWORD=root \ + -e POSTGRES_DB=postgres \ + --volume postgres_data:/var/lib/postgresql/data \ + --volume "$(pwd)"/hack/init_postgres.sql:/docker-entrypoint-initdb.d/init.sql \ + -d postgres:15 diff --git a/auth_v2.169.0/hack/test.env b/auth_v2.169.0/hack/test.env new file mode 100644 index 0000000..35e4b61 --- /dev/null +++ b/auth_v2.169.0/hack/test.env @@ -0,0 +1,128 @@ +GOTRUE_JWT_SECRET=testsecret +GOTRUE_JWT_EXP=3600 +GOTRUE_JWT_AUD="authenticated" +GOTRUE_JWT_ADMIN_ROLES="supabase_admin,service_role" +GOTRUE_JWT_DEFAULT_GROUP_NAME="authenticated" +GOTRUE_DB_DRIVER=postgres +DB_NAMESPACE="auth" +GOTRUE_DB_AUTOMIGRATE=true +DATABASE_URL="postgres://supabase_auth_admin:root@localhost:5432/postgres" +GOTRUE_API_HOST=localhost +PORT=9999 +API_EXTERNAL_URL="http://localhost:9999" +GOTRUE_LOG_SQL=none +GOTRUE_LOG_LEVEL=warn +GOTRUE_SITE_URL=https://example.netlify.com +GOTRUE_URI_ALLOW_LIST="http://localhost:3000" +GOTRUE_OPERATOR_TOKEN=foobar +GOTRUE_EXTERNAL_APPLE_ENABLED=true +GOTRUE_EXTERNAL_APPLE_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_APPLE_SECRET=testsecret 
+GOTRUE_EXTERNAL_APPLE_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_AZURE_ENABLED=true +GOTRUE_EXTERNAL_AZURE_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_AZURE_SECRET=testsecret +GOTRUE_EXTERNAL_AZURE_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_BITBUCKET_ENABLED=true +GOTRUE_EXTERNAL_BITBUCKET_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_BITBUCKET_SECRET=testsecret +GOTRUE_EXTERNAL_BITBUCKET_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_DISCORD_ENABLED=true +GOTRUE_EXTERNAL_DISCORD_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_DISCORD_SECRET=testsecret +GOTRUE_EXTERNAL_DISCORD_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_FACEBOOK_ENABLED=true +GOTRUE_EXTERNAL_FACEBOOK_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_FACEBOOK_SECRET=testsecret +GOTRUE_EXTERNAL_FACEBOOK_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_FLY_ENABLED=true +GOTRUE_EXTERNAL_FLY_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_FLY_SECRET=testsecret +GOTRUE_EXTERNAL_FLY_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_FIGMA_ENABLED=true +GOTRUE_EXTERNAL_FIGMA_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_FIGMA_SECRET=testsecret +GOTRUE_EXTERNAL_FIGMA_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_GITHUB_ENABLED=true +GOTRUE_EXTERNAL_GITHUB_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_GITHUB_SECRET=testsecret +GOTRUE_EXTERNAL_GITHUB_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_KAKAO_ENABLED=true +GOTRUE_EXTERNAL_KAKAO_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_KAKAO_SECRET=testsecret +GOTRUE_EXTERNAL_KAKAO_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_KEYCLOAK_ENABLED=true +GOTRUE_EXTERNAL_KEYCLOAK_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_KEYCLOAK_SECRET=testsecret +GOTRUE_EXTERNAL_KEYCLOAK_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_KEYCLOAK_URL=https://keycloak.example.com/auth/realms/myrealm +GOTRUE_EXTERNAL_LINKEDIN_ENABLED=true +GOTRUE_EXTERNAL_LINKEDIN_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_LINKEDIN_SECRET=testsecret +GOTRUE_EXTERNAL_LINKEDIN_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_LINKEDIN_OIDC_ENABLED=true +GOTRUE_EXTERNAL_LINKEDIN_OIDC_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_LINKEDIN_OIDC_SECRET=testsecret +GOTRUE_EXTERNAL_LINKEDIN_OIDC_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_GITLAB_ENABLED=true +GOTRUE_EXTERNAL_GITLAB_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_GITLAB_SECRET=testsecret +GOTRUE_EXTERNAL_GITLAB_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_GOOGLE_ENABLED=true +GOTRUE_EXTERNAL_GOOGLE_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_GOOGLE_SECRET=testsecret +GOTRUE_EXTERNAL_GOOGLE_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_NOTION_ENABLED=true +GOTRUE_EXTERNAL_NOTION_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_NOTION_SECRET=testsecret +GOTRUE_EXTERNAL_NOTION_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_SPOTIFY_ENABLED=true +GOTRUE_EXTERNAL_SPOTIFY_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_SPOTIFY_SECRET=testsecret +GOTRUE_EXTERNAL_SPOTIFY_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_SLACK_ENABLED=true +GOTRUE_EXTERNAL_SLACK_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_SLACK_SECRET=testsecret 
+GOTRUE_EXTERNAL_SLACK_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_SLACK_OIDC_ENABLED=true +GOTRUE_EXTERNAL_SLACK_OIDC_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_SLACK_OIDC_SECRET=testsecret +GOTRUE_EXTERNAL_SLACK_OIDC_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_WORKOS_ENABLED=true +GOTRUE_EXTERNAL_WORKOS_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_WORKOS_SECRET=testsecret +GOTRUE_EXTERNAL_WORKOS_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_TWITCH_ENABLED=true +GOTRUE_EXTERNAL_TWITCH_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_TWITCH_SECRET=testsecret +GOTRUE_EXTERNAL_TWITCH_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_TWITTER_ENABLED=true +GOTRUE_EXTERNAL_TWITTER_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_TWITTER_SECRET=testsecret +GOTRUE_EXTERNAL_TWITTER_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_ZOOM_ENABLED=true +GOTRUE_EXTERNAL_ZOOM_CLIENT_ID=testclientid +GOTRUE_EXTERNAL_ZOOM_SECRET=testsecret +GOTRUE_EXTERNAL_ZOOM_REDIRECT_URI=https://identity.services.netlify.com/callback +GOTRUE_EXTERNAL_FLOW_STATE_EXPIRY_DURATION="300s" +GOTRUE_RATE_LIMIT_VERIFY="100000" +GOTRUE_RATE_LIMIT_TOKEN_REFRESH="30" +GOTRUE_RATE_LIMIT_ANONYMOUS_USERS="5" +GOTRUE_RATE_LIMIT_HEADER="My-Custom-Header" +GOTRUE_TRACING_ENABLED=true +GOTRUE_TRACING_EXPORTER=default +GOTRUE_TRACING_HOST=127.0.0.1 +GOTRUE_TRACING_PORT=8126 +GOTRUE_TRACING_TAGS="env:test" +GOTRUE_SECURITY_CAPTCHA_ENABLED="false" +GOTRUE_SECURITY_CAPTCHA_PROVIDER="hcaptcha" +GOTRUE_SECURITY_CAPTCHA_SECRET="0x0000000000000000000000000000000000000000" +GOTRUE_SECURITY_CAPTCHA_TIMEOUT="10s" +GOTRUE_SAML_ENABLED="true" +GOTRUE_SAML_PRIVATE_KEY="MIIEowIBAAKCAQEAszrVveMQcSsa0Y+zN1ZFb19cRS0jn4UgIHTprW2tVBmO2PABzjY3XFCfx6vPirMAPWBYpsKmXrvm1tr0A6DZYmA8YmJd937VUQ67fa6DMyppBYTjNgGEkEhmKuszvF3MARsIKCGtZqUrmS7UG4404wYxVppnr2EYm3RGtHlkYsXu20MBqSDXP47bQP+PkJqC3BuNGk3xt5UHl2FSFpTHelkI6lBynw16B+lUT1F96SERNDaMqi/TRsZdGe5mB/29ngC/QBMpEbRBLNRir5iUevKS7Pn4aph9Qjaxx/97siktK210FJT23KjHpgcUfjoQ6BgPBTLtEeQdRyDuc/CgfwIDAQABAoIBAGYDWOEpupQPSsZ4mjMnAYJwrp4ZISuMpEqVAORbhspVeb70bLKonT4IDcmiexCg7cQBcLQKGpPVM4CbQ0RFazXZPMVq470ZDeWDEyhoCfk3bGtdxc1Zc9CDxNMs6FeQs6r1beEZug6weG5J/yRn/qYxQife3qEuDMl+lzfl2EN3HYVOSnBmdt50dxRuX26iW3nqqbMRqYn9OHuJ1LvRRfYeyVKqgC5vgt/6Tf7DAJwGe0dD7q08byHV8DBZ0pnMVU0bYpf1GTgMibgjnLjK//EVWafFHtN+RXcjzGmyJrk3+7ZyPUpzpDjO21kpzUQLrpEkkBRnmg6bwHnSrBr8avECgYEA3pq1PTCAOuLQoIm1CWR9/dhkbJQiKTJevlWV8slXQLR50P0WvI2RdFuSxlWmA4xZej8s4e7iD3MYye6SBsQHygOVGc4efvvEZV8/XTlDdyj7iLVGhnEmu2r7AFKzy8cOvXx0QcLg+zNd7vxZv/8D3Qj9Jje2LjLHKM5n/dZ3RzUCgYEAzh5Lo2anc4WN8faLGt7rPkGQF+7/18ImQE11joHWa3LzAEy7FbeOGpE/vhOv5umq5M/KlWFIRahMEQv4RusieHWI19ZLIP+JwQFxWxS+cPp3xOiGcquSAZnlyVSxZ//dlVgaZq2o2MfrxECcovRlaknl2csyf+HjFFwKlNxHm2MCgYAr//R3BdEy0oZeVRndo2lr9YvUEmu2LOihQpWDCd0fQw0ZDA2kc28eysL2RROte95r1XTvq6IvX5a0w11FzRWlDpQ4J4/LlcQ6LVt+98SoFwew+/PWuyLmxLycUbyMOOpm9eSc4wJJZNvaUzMCSkvfMtmm5jgyZYMMQ9A2Ul/9SQKBgB9mfh9mhBwVPIqgBJETZMMXOdxrjI5SBYHGSyJqpT+5Q0vIZLfqPrvNZOiQFzwWXPJ+tV4Mc/YorW3rZOdo6tdvEGnRO6DLTTEaByrY/io3/gcBZXoSqSuVRmxleqFdWWRnB56c1hwwWLqNHU+1671FhL6pNghFYVK4suP6qu4BAoGBAMk+VipXcIlD67mfGrET/xDqiWWBZtgTzTMjTpODhDY1GZck1eb4CQMP5j5V3gFJ4cSgWDJvnWg8rcz0unz/q4aeMGl1rah5WNDWj1QKWMS6vJhMHM/rqN1WHWR0ZnV83svYgtg0zDnQKlLujqW4JmGXLMU7ur6a+e6lpa1fvLsP" +GOTRUE_MAX_VERIFIED_FACTORS=10 +GOTRUE_SMS_TEST_OTP_VALID_UNTIL="" +GOTRUE_SECURITY_DB_ENCRYPTION_ENCRYPT=true +GOTRUE_SECURITY_DB_ENCRYPTION_ENCRYPTION_KEY_ID=abc 
+GOTRUE_SECURITY_DB_ENCRYPTION_ENCRYPTION_KEY=pwFoiPyybQMqNmYVN0gUnpbfpGQV2sDv9vp0ZAxi_Y4 +GOTRUE_SECURITY_DB_ENCRYPTION_DECRYPTION_KEYS=abc:pwFoiPyybQMqNmYVN0gUnpbfpGQV2sDv9vp0ZAxi_Y4 diff --git a/auth_v2.169.0/init_postgres.sh b/auth_v2.169.0/init_postgres.sh new file mode 100644 index 0000000..134e179 --- /dev/null +++ b/auth_v2.169.0/init_postgres.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -e + +psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL + CREATE USER supabase_admin LOGIN CREATEROLE CREATEDB REPLICATION BYPASSRLS; + + -- Supabase super admin + CREATE USER supabase_auth_admin NOINHERIT CREATEROLE LOGIN NOREPLICATION PASSWORD 'root'; + CREATE SCHEMA IF NOT EXISTS $DB_NAMESPACE AUTHORIZATION supabase_auth_admin; + GRANT CREATE ON DATABASE postgres TO supabase_auth_admin; + ALTER USER supabase_auth_admin SET search_path = '$DB_NAMESPACE'; +EOSQL diff --git a/auth_v2.169.0/internal/api/admin.go b/auth_v2.169.0/internal/api/admin.go new file mode 100644 index 0000000..63cde06 --- /dev/null +++ b/auth_v2.169.0/internal/api/admin.go @@ -0,0 +1,642 @@ +package api + +import ( + "context" + "net/http" + "time" + + "github.com/fatih/structs" + "github.com/go-chi/chi/v5" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/sethvargo/go-password/password" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" + "golang.org/x/crypto/bcrypt" +) + +type AdminUserParams struct { + Id string `json:"id"` + Aud string `json:"aud"` + Role string `json:"role"` + Email string `json:"email"` + Phone string `json:"phone"` + Password *string `json:"password"` + PasswordHash string `json:"password_hash"` + EmailConfirm bool `json:"email_confirm"` + PhoneConfirm bool `json:"phone_confirm"` + UserMetaData map[string]interface{} `json:"user_metadata"` + AppMetaData map[string]interface{} `json:"app_metadata"` + BanDuration string `json:"ban_duration"` +} + +type adminUserDeleteParams struct { + ShouldSoftDelete bool `json:"should_soft_delete"` +} + +type adminUserUpdateFactorParams struct { + FriendlyName string `json:"friendly_name"` + Phone string `json:"phone"` +} + +type AdminListUsersResponse struct { + Users []*models.User `json:"users"` + Aud string `json:"aud"` +} + +func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx := r.Context() + db := a.db.WithContext(ctx) + + userID, err := uuid.FromString(chi.URLParam(r, "user_id")) + if err != nil { + return nil, notFoundError(ErrorCodeValidationFailed, "user_id must be an UUID") + } + + observability.LogEntrySetField(r, "user_id", userID) + + u, err := models.FindUserByID(db, userID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, notFoundError(ErrorCodeUserNotFound, "User not found") + } + return nil, internalServerError("Database error loading user").WithInternalError(err) + } + + return withUser(ctx, u), nil +} + +// Use only after requireAuthentication, so that there is a valid user +func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx := r.Context() + db := a.db.WithContext(ctx) + user := getUser(ctx) + factorID, err := uuid.FromString(chi.URLParam(r, "factor_id")) + if err != nil { + return nil, notFoundError(ErrorCodeValidationFailed, "factor_id must be an UUID") + } + + observability.LogEntrySetField(r, "factor_id", 
factorID) + + factor, err := user.FindOwnedFactorByID(db, factorID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, notFoundError(ErrorCodeMFAFactorNotFound, "Factor not found") + } + return nil, internalServerError("Database error loading factor").WithInternalError(err) + } + return withFactor(ctx, factor), nil +} + +func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) { + params := &AdminUserParams{} + if err := retrieveRequestParams(r, params); err != nil { + return nil, err + } + + return params, nil +} + +// adminUsers responds with a list of all users in a given audience +func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + aud := a.requestAud(ctx, r) + + pageParams, err := paginate(r) + if err != nil { + return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err) + } + + sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}}) + if err != nil { + return badRequestError(ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err) + } + + filter := r.URL.Query().Get("filter") + + users, err := models.FindUsersInAudience(db, aud, pageParams, sortParams, filter) + if err != nil { + return internalServerError("Database error finding users").WithInternalError(err) + } + addPaginationHeaders(w, r, pageParams) + + return sendJSON(w, http.StatusOK, AdminListUsersResponse{ + Users: users, + Aud: aud, + }) +} + +// adminUserGet returns information about a single user +func (a *API) adminUserGet(w http.ResponseWriter, r *http.Request) error { + user := getUser(r.Context()) + + return sendJSON(w, http.StatusOK, user) +} + +// adminUserUpdate updates a single user object +func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + user := getUser(ctx) + adminUser := getAdminUser(ctx) + params, err := a.getAdminParams(r) + if err != nil { + return err + } + + if params.Email != "" { + params.Email, err = a.validateEmail(params.Email) + if err != nil { + return err + } + } + + if params.Phone != "" { + params.Phone, err = validatePhone(params.Phone) + if err != nil { + return err + } + } + + var banDuration *time.Duration + if params.BanDuration != "" { + duration := time.Duration(0) + if params.BanDuration != "none" { + duration, err = time.ParseDuration(params.BanDuration) + if err != nil { + return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) + } + } + banDuration = &duration + } + + if params.Password != nil { + password := *params.Password + + if err := a.checkPasswordStrength(ctx, password); err != nil { + return err + } + + if err := user.SetPassword(ctx, password, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + return err + } + } + + err = db.Transaction(func(tx *storage.Connection) error { + if params.Role != "" { + if terr := user.SetRole(tx, params.Role); terr != nil { + return terr + } + } + + if params.EmailConfirm { + if terr := user.Confirm(tx); terr != nil { + return terr + } + } + + if params.PhoneConfirm { + if terr := user.ConfirmPhone(tx); terr != nil { + return terr + } + } + + if params.Password != nil { + if terr := user.UpdatePassword(tx, nil); terr != nil { + return terr + } + } + + var identities 
[]models.Identity + if params.Email != "" { + if identity, terr := models.FindIdentityByIdAndProvider(tx, user.ID.String(), "email"); terr != nil && !models.IsNotFoundError(terr) { + return terr + } else if identity == nil { + // if the user doesn't have an existing email + // then updating the user's email should create a new email identity + i, terr := a.createNewIdentity(tx, user, "email", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Email: params.Email, + EmailVerified: params.EmailConfirm, + })) + if terr != nil { + return terr + } + identities = append(identities, *i) + } else { + // update the existing email identity + if terr := identity.UpdateIdentityData(tx, map[string]interface{}{ + "email": params.Email, + "email_verified": params.EmailConfirm, + }); terr != nil { + return terr + } + } + if user.IsAnonymous && params.EmailConfirm { + user.IsAnonymous = false + if terr := tx.UpdateOnly(user, "is_anonymous"); terr != nil { + return terr + } + } + + if terr := user.SetEmail(tx, params.Email); terr != nil { + return terr + } + } + + if params.Phone != "" { + if identity, terr := models.FindIdentityByIdAndProvider(tx, user.ID.String(), "phone"); terr != nil && !models.IsNotFoundError(terr) { + return terr + } else if identity == nil { + // if the user doesn't have an existing phone + // then updating the user's phone should create a new phone identity + identity, terr := a.createNewIdentity(tx, user, "phone", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Phone: params.Phone, + PhoneVerified: params.PhoneConfirm, + })) + if terr != nil { + return terr + } + identities = append(identities, *identity) + } else { + // update the existing phone identity + if terr := identity.UpdateIdentityData(tx, map[string]interface{}{ + "phone": params.Phone, + "phone_verified": params.PhoneConfirm, + }); terr != nil { + return terr + } + } + if user.IsAnonymous && params.PhoneConfirm { + user.IsAnonymous = false + if terr := tx.UpdateOnly(user, "is_anonymous"); terr != nil { + return terr + } + } + if terr := user.SetPhone(tx, params.Phone); terr != nil { + return terr + } + } + user.Identities = append(user.Identities, identities...) 
+ + if params.AppMetaData != nil { + if terr := user.UpdateAppMetaData(tx, params.AppMetaData); terr != nil { + return terr + } + } + + if params.UserMetaData != nil { + if terr := user.UpdateUserMetaData(tx, params.UserMetaData); terr != nil { + return terr + } + } + + if banDuration != nil { + if terr := user.Ban(tx, *banDuration); terr != nil { + return terr + } + } + + if terr := models.NewAuditLogEntry(r, tx, adminUser, models.UserModifiedAction, "", map[string]interface{}{ + "user_id": user.ID, + "user_email": user.Email, + "user_phone": user.Phone, + }); terr != nil { + return terr + } + return nil + }) + + if err != nil { + return internalServerError("Error updating user").WithInternalError(err) + } + + return sendJSON(w, http.StatusOK, user) +} + +// adminUserCreate creates a new user based on the provided data +func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + + adminUser := getAdminUser(ctx) + params, err := a.getAdminParams(r) + if err != nil { + return err + } + + aud := a.requestAud(ctx, r) + if params.Aud != "" { + aud = params.Aud + } + + if params.Email == "" && params.Phone == "" { + return badRequestError(ErrorCodeValidationFailed, "Cannot create a user without either an email or phone") + } + + var providers []string + if params.Email != "" { + params.Email, err = a.validateEmail(params.Email) + if err != nil { + return err + } + if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil { + return internalServerError("Database error checking email").WithInternalError(err) + } else if user != nil { + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) + } + providers = append(providers, "email") + } + + if params.Phone != "" { + params.Phone, err = validatePhone(params.Phone) + if err != nil { + return err + } + if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil { + return internalServerError("Database error checking phone").WithInternalError(err) + } else if exists { + return unprocessableEntityError(ErrorCodePhoneExists, "Phone number already registered by another user") + } + providers = append(providers, "phone") + } + + if params.Password != nil && params.PasswordHash != "" { + return badRequestError(ErrorCodeValidationFailed, "Only a password or a password hash should be provided") + } + + if (params.Password == nil || *params.Password == "") && params.PasswordHash == "" { + password, err := password.Generate(64, 10, 0, false, true) + if err != nil { + return internalServerError("Error generating password").WithInternalError(err) + } + params.Password = &password + } + + var user *models.User + if params.PasswordHash != "" { + user, err = models.NewUserWithPasswordHash(params.Phone, params.Email, params.PasswordHash, aud, params.UserMetaData) + } else { + user, err = models.NewUser(params.Phone, params.Email, *params.Password, aud, params.UserMetaData) + } + + if err != nil { + if errors.Is(err, bcrypt.ErrPasswordTooLong) { + return badRequestError(ErrorCodeValidationFailed, err.Error()) + } + return internalServerError("Error creating user").WithInternalError(err) + } + + if params.Id != "" { + customId, err := uuid.FromString(params.Id) + if err != nil { + return badRequestError(ErrorCodeValidationFailed, "ID must conform to the uuid v4 format") + } + if customId == uuid.Nil { + return badRequestError(ErrorCodeValidationFailed, "ID cannot be a nil uuid") + } + user.ID = customId + } + + 
user.AppMetaData = map[string]interface{}{ + // TODO: Deprecate "provider" field + // default to the first provider in the providers slice + "provider": providers[0], + "providers": providers, + } + + var banDuration *time.Duration + if params.BanDuration != "" { + duration := time.Duration(0) + if params.BanDuration != "none" { + duration, err = time.ParseDuration(params.BanDuration) + if err != nil { + return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) + } + } + banDuration = &duration + } + + err = db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Create(user); terr != nil { + return terr + } + + var identities []models.Identity + if user.GetEmail() != "" { + identity, terr := a.createNewIdentity(tx, user, "email", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Email: user.GetEmail(), + })) + + if terr != nil { + return terr + } + identities = append(identities, *identity) + } + + if user.GetPhone() != "" { + identity, terr := a.createNewIdentity(tx, user, "phone", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Phone: user.GetPhone(), + })) + + if terr != nil { + return terr + } + identities = append(identities, *identity) + } + + user.Identities = identities + + if terr := models.NewAuditLogEntry(r, tx, adminUser, models.UserSignedUpAction, "", map[string]interface{}{ + "user_id": user.ID, + "user_email": user.Email, + "user_phone": user.Phone, + }); terr != nil { + return terr + } + + role := config.JWT.DefaultGroupName + if params.Role != "" { + role = params.Role + } + if terr := user.SetRole(tx, role); terr != nil { + return terr + } + + if params.AppMetaData != nil { + if terr := user.UpdateAppMetaData(tx, params.AppMetaData); terr != nil { + return terr + } + } + + if params.EmailConfirm { + if terr := user.Confirm(tx); terr != nil { + return terr + } + } + + if params.PhoneConfirm { + if terr := user.ConfirmPhone(tx); terr != nil { + return terr + } + } + + if banDuration != nil { + if terr := user.Ban(tx, *banDuration); terr != nil { + return terr + } + } + + return nil + }) + + if err != nil { + return internalServerError("Database error creating new user").WithInternalError(err) + } + + return sendJSON(w, http.StatusOK, user) +} + +// adminUserDelete deletes a user +func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + user := getUser(ctx) + adminUser := getAdminUser(ctx) + + // ShouldSoftDelete defaults to false + params := &adminUserDeleteParams{} + if body, _ := utilities.GetBodyBytes(r); len(body) != 0 { + // we only want to parse the body if it's not empty + // retrieveRequestParams will handle any errors with stream + if err := retrieveRequestParams(r, params); err != nil { + return err + } + } + + err := a.db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, adminUser, models.UserDeletedAction, "", map[string]interface{}{ + "user_id": user.ID, + "user_email": user.Email, + "user_phone": user.Phone, + }); terr != nil { + return internalServerError("Error recording audit log entry").WithInternalError(terr) + } + + if params.ShouldSoftDelete { + if user.DeletedAt != nil { + // user has been soft deleted already + return nil + } + if terr := user.SoftDeleteUser(tx); terr != nil { + return internalServerError("Error soft deleting user").WithInternalError(terr) + } + + if terr := user.SoftDeleteUserIdentities(tx); terr != nil { + return internalServerError("Error soft deleting user 
identities").WithInternalError(terr) + } + + // hard delete all associated factors + if terr := models.DeleteFactorsByUserId(tx, user.ID); terr != nil { + return internalServerError("Error deleting user's factors").WithInternalError(terr) + } + // hard delete all associated sessions + if terr := models.Logout(tx, user.ID); terr != nil { + return internalServerError("Error deleting user's sessions").WithInternalError(terr) + } + } else { + if terr := tx.Destroy(user); terr != nil { + return internalServerError("Database error deleting user").WithInternalError(terr) + } + } + + return nil + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, map[string]interface{}{}) +} + +func (a *API) adminUserDeleteFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + user := getUser(ctx) + factor := getFactor(ctx) + + err := a.db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, user, models.DeleteFactorAction, r.RemoteAddr, map[string]interface{}{ + "user_id": user.ID, + "factor_id": factor.ID, + }); terr != nil { + return terr + } + if terr := tx.Destroy(factor); terr != nil { + return internalServerError("Database error deleting factor").WithInternalError(terr) + } + return nil + }) + if err != nil { + return err + } + return sendJSON(w, http.StatusOK, factor) +} + +func (a *API) adminUserGetFactors(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + user := getUser(ctx) + return sendJSON(w, http.StatusOK, user.Factors) +} + +// adminUserUpdate updates a single factor object +func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + factor := getFactor(ctx) + user := getUser(ctx) + adminUser := getAdminUser(ctx) + params := &adminUserUpdateFactorParams{} + + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + err := a.db.Transaction(func(tx *storage.Connection) error { + if params.FriendlyName != "" { + if terr := factor.UpdateFriendlyName(tx, params.FriendlyName); terr != nil { + return terr + } + } + + if params.Phone != "" && factor.IsPhoneFactor() { + phone, err := validatePhone(params.Phone) + if err != nil { + return badRequestError(ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") + } + if terr := factor.UpdatePhone(tx, phone); terr != nil { + return terr + } + } + + if terr := models.NewAuditLogEntry(r, tx, adminUser, models.UpdateFactorAction, "", map[string]interface{}{ + "user_id": user.ID, + "factor_id": factor.ID, + "factor_type": factor.FactorType, + }); terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, factor) +} diff --git a/auth_v2.169.0/internal/api/admin_test.go b/auth_v2.169.0/internal/api/admin_test.go new file mode 100644 index 0000000..a2070d7 --- /dev/null +++ b/auth_v2.169.0/internal/api/admin_test.go @@ -0,0 +1,915 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gofrs/uuid" + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type AdminTestSuite struct { + suite.Suite + User *models.User + API *API + Config *conf.GlobalConfiguration + + token string +} + +func TestAdmin(t *testing.T) { + api, config, err := 
setupAPIForTest() + require.NoError(t, err) + + ts := &AdminTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *AdminTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + ts.Config.External.Email.Enabled = true + claims := &AccessTokenClaims{ + Role: "supabase_admin", + } + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err, "Error generating admin jwt") + ts.token = token +} + +// TestAdminUsersUnauthorized tests API /admin/users route without authentication +func (ts *AdminTestSuite) TestAdminUsersUnauthorized() { + req := httptest.NewRequest(http.MethodGet, "/admin/users", nil) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusUnauthorized, w.Code) +} + +// TestAdminUsers tests API /admin/users route +func (ts *AdminTestSuite) TestAdminUsers() { + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/users", nil) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + assert.Equal(ts.T(), "; rel=\"last\"", w.Header().Get("Link")) + assert.Equal(ts.T(), "0", w.Header().Get("X-Total-Count")) +} + +// TestAdminUsers tests API /admin/users route +func (ts *AdminTestSuite) TestAdminUsers_Pagination() { + u, err := models.NewUser("12345678", "test1@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + u, err = models.NewUser("987654321", "test2@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/users?per_page=1", nil) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + assert.Equal(ts.T(), "; rel=\"next\", ; rel=\"last\"", w.Header().Get("Link")) + assert.Equal(ts.T(), "2", w.Header().Get("X-Total-Count")) + + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + for _, user := range data["users"].([]interface{}) { + assert.NotEmpty(ts.T(), user) + } +} + +// TestAdminUsers tests API /admin/users route +func (ts *AdminTestSuite) TestAdminUsers_SortAsc() { + u, err := models.NewUser("", "test1@example.com", "test", ts.Config.JWT.Aud, nil) + u.CreatedAt = time.Now().Add(-time.Minute) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + u, err = models.NewUser("", "test2@example.com", "test", ts.Config.JWT.Aud, nil) + u.CreatedAt = time.Now() + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/users", nil) + qv := req.URL.Query() + qv.Set("sort", "created_at asc") + req.URL.RawQuery = qv.Encode() + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := struct { + Users []*models.User `json:"users"` + 
Aud string `json:"aud"` + }{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.Len(ts.T(), data.Users, 2) + assert.Equal(ts.T(), "test1@example.com", data.Users[0].GetEmail()) + assert.Equal(ts.T(), "test2@example.com", data.Users[1].GetEmail()) +} + +// TestAdminUsers tests API /admin/users route +func (ts *AdminTestSuite) TestAdminUsers_SortDesc() { + u, err := models.NewUser("12345678", "test1@example.com", "test", ts.Config.JWT.Aud, nil) + u.CreatedAt = time.Now().Add(-time.Minute) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + u, err = models.NewUser("987654321", "test2@example.com", "test", ts.Config.JWT.Aud, nil) + u.CreatedAt = time.Now() + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/users", nil) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := struct { + Users []*models.User `json:"users"` + Aud string `json:"aud"` + }{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.Len(ts.T(), data.Users, 2) + assert.Equal(ts.T(), "test2@example.com", data.Users[0].GetEmail()) + assert.Equal(ts.T(), "test1@example.com", data.Users[1].GetEmail()) +} + +// TestAdminUsers tests API /admin/users route +func (ts *AdminTestSuite) TestAdminUsers_FilterEmail() { + u, err := models.NewUser("", "test1@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/users?filter=test1", nil) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := struct { + Users []*models.User `json:"users"` + Aud string `json:"aud"` + }{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.Len(ts.T(), data.Users, 1) + assert.Equal(ts.T(), "test1@example.com", data.Users[0].GetEmail()) +} + +// TestAdminUsers tests API /admin/users route +func (ts *AdminTestSuite) TestAdminUsers_FilterName() { + u, err := models.NewUser("", "test1@example.com", "test", ts.Config.JWT.Aud, map[string]interface{}{"full_name": "Test User"}) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + u, err = models.NewUser("", "test2@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/users?filter=User", nil) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := struct { + Users []*models.User `json:"users"` + Aud string `json:"aud"` + }{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.Len(ts.T(), data.Users, 1) + assert.Equal(ts.T(), "test1@example.com", data.Users[0].GetEmail()) +} + +// TestAdminUserCreate tests API /admin/user route 
(POST) +func (ts *AdminTestSuite) TestAdminUserCreate() { + cases := []struct { + desc string + params map[string]interface{} + expected map[string]interface{} + }{ + { + desc: "Only phone", + params: map[string]interface{}{ + "phone": "123456789", + "password": "test1", + }, + expected: map[string]interface{}{ + "email": "", + "phone": "123456789", + "isAuthenticated": true, + "provider": "phone", + "providers": []string{"phone"}, + "password": "test1", + }, + }, + { + desc: "With password", + params: map[string]interface{}{ + "email": "test1@example.com", + "phone": "123456789", + "password": "test1", + }, + expected: map[string]interface{}{ + "email": "test1@example.com", + "phone": "123456789", + "isAuthenticated": true, + "provider": "email", + "providers": []string{"email", "phone"}, + "password": "test1", + }, + }, + { + desc: "Without password", + params: map[string]interface{}{ + "email": "test2@example.com", + "phone": "", + }, + expected: map[string]interface{}{ + "email": "test2@example.com", + "phone": "", + "isAuthenticated": false, + "provider": "email", + "providers": []string{"email"}, + }, + }, + { + desc: "With empty string password", + params: map[string]interface{}{ + "email": "test3@example.com", + "phone": "", + "password": "", + }, + expected: map[string]interface{}{ + "email": "test3@example.com", + "phone": "", + "isAuthenticated": false, + "provider": "email", + "providers": []string{"email"}, + "password": "", + }, + }, + { + desc: "Ban created user", + params: map[string]interface{}{ + "email": "test4@example.com", + "phone": "", + "password": "test1", + "ban_duration": "24h", + }, + expected: map[string]interface{}{ + "email": "test4@example.com", + "phone": "", + "isAuthenticated": true, + "provider": "email", + "providers": []string{"email"}, + "password": "test1", + }, + }, + { + desc: "With password hash", + params: map[string]interface{}{ + "email": "test5@example.com", + "password_hash": "$2y$10$SXEz2HeT8PUIGQXo9yeUIem8KzNxgG0d7o/.eGj2rj8KbRgAuRVlq", + }, + expected: map[string]interface{}{ + "email": "test5@example.com", + "phone": "", + "isAuthenticated": true, + "provider": "email", + "providers": []string{"email"}, + "password": "test", + }, + }, + { + desc: "With custom id", + params: map[string]interface{}{ + "id": "fc56ab41-2010-4870-a9b9-767c1dc573fb", + "email": "test6@example.com", + "password": "test", + }, + expected: map[string]interface{}{ + "id": "fc56ab41-2010-4870-a9b9-767c1dc573fb", + "email": "test6@example.com", + "phone": "", + "isAuthenticated": true, + "provider": "email", + "providers": []string{"email"}, + "password": "test", + }, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/admin/users", &buffer) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + ts.Config.External.Phone.Enabled = true + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := models.User{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + assert.Equal(ts.T(), c.expected["email"], data.GetEmail()) + assert.Equal(ts.T(), c.expected["phone"], data.GetPhone()) + assert.Equal(ts.T(), c.expected["provider"], data.AppMetaData["provider"]) + assert.ElementsMatch(ts.T(), c.expected["providers"], data.AppMetaData["providers"]) + + u, err := models.FindUserByID(ts.API.db, data.ID) + 
require.NoError(ts.T(), err) + + // verify that the corresponding identities were created + require.NotEmpty(ts.T(), u.Identities) + for _, identity := range u.Identities { + require.Equal(ts.T(), u.ID, identity.UserID) + if identity.Provider == "email" { + require.Equal(ts.T(), c.expected["email"], identity.IdentityData["email"]) + } + if identity.Provider == "phone" { + require.Equal(ts.T(), c.expected["phone"], identity.IdentityData["phone"]) + } + } + + if _, ok := c.expected["password"]; ok { + expectedPassword := fmt.Sprintf("%v", c.expected["password"]) + isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, expectedPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID) + require.NoError(ts.T(), err) + require.Equal(ts.T(), c.expected["isAuthenticated"], isAuthenticated) + } + + if id, ok := c.expected["id"]; ok { + uid, err := uuid.FromString(id.(string)) + require.NoError(ts.T(), err) + require.Equal(ts.T(), uid, data.ID) + } + + // remove created user after each case + require.NoError(ts.T(), ts.API.db.Destroy(u)) + }) + } +} + +// TestAdminUserGet tests API /admin/user route (GET) +func (ts *AdminTestSuite) TestAdminUserGet() { + u, err := models.NewUser("12345678", "test1@example.com", "test", ts.Config.JWT.Aud, map[string]interface{}{"full_name": "Test Get User"}) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/admin/users/%s", u.ID), nil) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + assert.Equal(ts.T(), data["email"], "test1@example.com") + assert.NotNil(ts.T(), data["app_metadata"]) + assert.NotNil(ts.T(), data["user_metadata"]) + md := data["user_metadata"].(map[string]interface{}) + assert.Len(ts.T(), md, 1) + assert.Equal(ts.T(), "Test Get User", md["full_name"]) +} + +// TestAdminUserUpdate tests API /admin/user route (UPDATE) +func (ts *AdminTestSuite) TestAdminUserUpdate() { + u, err := models.NewUser("12345678", "test1@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + var buffer bytes.Buffer + newEmail := "test2@example.com" + newPhone := "234567890" + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "role": "testing", + "app_metadata": map[string]interface{}{ + "roles": []string{"writer", "editor"}, + }, + "user_metadata": map[string]interface{}{ + "name": "David", + }, + "ban_duration": "24h", + "email": newEmail, + "phone": newPhone, + })) + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/admin/users/%s", u.ID), &buffer) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := models.User{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + assert.Equal(ts.T(), "testing", data.Role) + assert.NotNil(ts.T(), data.UserMetaData) + assert.Equal(ts.T(), "David", data.UserMetaData["name"]) + 
assert.Equal(ts.T(), newEmail, data.GetEmail()) + assert.Equal(ts.T(), newPhone, data.GetPhone()) + + assert.NotNil(ts.T(), data.AppMetaData) + assert.Len(ts.T(), data.AppMetaData["roles"], 2) + assert.Contains(ts.T(), data.AppMetaData["roles"], "writer") + assert.Contains(ts.T(), data.AppMetaData["roles"], "editor") + assert.NotNil(ts.T(), data.BannedUntil) + + u, err = models.FindUserByID(ts.API.db, data.ID) + require.NoError(ts.T(), err) + + // check if the corresponding identities were successfully created + require.NotEmpty(ts.T(), u.Identities) + + for _, identity := range u.Identities { + // for email & phone identities, the providerId is the same as the userId + require.Equal(ts.T(), u.ID.String(), identity.ProviderID) + require.Equal(ts.T(), u.ID, identity.UserID) + if identity.Provider == "email" { + require.Equal(ts.T(), newEmail, identity.IdentityData["email"]) + } + if identity.Provider == "phone" { + require.Equal(ts.T(), newPhone, identity.IdentityData["phone"]) + + } + } +} + +func (ts *AdminTestSuite) TestAdminUserUpdatePasswordFailed() { + u, err := models.NewUser("12345678", "test1@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + var updateEndpoint = fmt.Sprintf("/admin/users/%s", u.ID) + ts.Config.Password.MinLength = 6 + ts.Run("Password doesn't meet minimum length", func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "password": "", + })) + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, updateEndpoint, &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusUnprocessableEntity, w.Code) + }) +} + +func (ts *AdminTestSuite) TestAdminUserUpdateBannedUntilFailed() { + u, err := models.NewUser("", "test1@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + var updateEndpoint = fmt.Sprintf("/admin/users/%s", u.ID) + ts.Config.Password.MinLength = 6 + ts.Run("Incorrect format for ban_duration", func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "ban_duration": "24", + })) + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, updateEndpoint, &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusBadRequest, w.Code) + }) +} + +// TestAdminUserDelete tests API /admin/users route (DELETE) +func (ts *AdminTestSuite) TestAdminUserDelete() { + type expected struct { + code int + err error + } + signupParams := &SignupParams{ + Email: "test-delete@example.com", + Password: "test", + Data: map[string]interface{}{"name": "test"}, + Provider: "email", + Aud: ts.Config.JWT.Aud, + } + cases := []struct { + desc string + body map[string]interface{} + isSoftDelete string + isSSOUser bool + expected expected + }{ + { + desc: "Test admin delete user (default)", + isSoftDelete: "", + isSSOUser: false, + expected: expected{code: http.StatusOK, err: models.UserNotFoundError{}}, + body: nil, + }, + { + desc: "Test admin delete user (hard deletion)", + isSoftDelete: "?is_soft_delete=false", + isSSOUser: false, + expected: expected{code: 
http.StatusOK, err: models.UserNotFoundError{}}, + body: map[string]interface{}{ + "should_soft_delete": false, + }, + }, + { + desc: "Test admin delete user (soft deletion)", + isSoftDelete: "?is_soft_delete=true", + isSSOUser: false, + expected: expected{code: http.StatusOK, err: models.UserNotFoundError{}}, + body: map[string]interface{}{ + "should_soft_delete": true, + }, + }, + { + desc: "Test admin delete user (soft deletion & sso user)", + isSoftDelete: "?is_soft_delete=true", + isSSOUser: true, + expected: expected{code: http.StatusOK, err: nil}, + body: map[string]interface{}{ + "should_soft_delete": true, + }, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + u, err := signupParams.ToUserModel(false /* <- isSSOUser */) + require.NoError(ts.T(), err) + u, err = ts.API.signupNewUser(ts.API.db, u) + require.NoError(ts.T(), err) + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/admin/users/%s", u.ID), &buffer) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expected.code, w.Code) + + if c.isSSOUser { + u, err = models.FindUserByID(ts.API.db, u.ID) + require.NotNil(ts.T(), u) + } else { + _, err = models.FindUserByEmailAndAudience(ts.API.db, signupParams.Email, ts.Config.JWT.Aud) + } + require.Equal(ts.T(), c.expected.err, err) + }) + } +} + +func (ts *AdminTestSuite) TestAdminUserSoftDeletion() { + // create user + u, err := models.NewUser("123456789", "test@example.com", "secret", ts.Config.JWT.Aud, map[string]interface{}{"name": "test"}) + require.NoError(ts.T(), err) + u.ConfirmationToken = "some_token" + u.RecoveryToken = "some_token" + u.EmailChangeTokenCurrent = "some_token" + u.EmailChangeTokenNew = "some_token" + u.PhoneChangeToken = "some_token" + u.AppMetaData = map[string]interface{}{ + "provider": "email", + } + require.NoError(ts.T(), ts.API.db.Create(u)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.EmailChangeTokenNew, models.EmailChangeTokenNew)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetPhone(), u.PhoneChangeToken, models.PhoneChangeToken)) + + // create user identities + _, err = ts.API.createNewIdentity(ts.API.db, u, "email", map[string]interface{}{ + "sub": "123456", + "email": "test@example.com", + }) + require.NoError(ts.T(), err) + _, err = ts.API.createNewIdentity(ts.API.db, u, "github", map[string]interface{}{ + "sub": "234567", + "email": "test@example.com", + }) + require.NoError(ts.T(), err) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "should_soft_delete": true, + })) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/admin/users/%s", u.ID), &buffer) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // get soft-deleted user 
from db + deletedUser, err := models.FindUserByID(ts.API.db, u.ID) + require.NoError(ts.T(), err) + + require.Empty(ts.T(), deletedUser.ConfirmationToken) + require.Empty(ts.T(), deletedUser.RecoveryToken) + require.Empty(ts.T(), deletedUser.EmailChangeTokenCurrent) + require.Empty(ts.T(), deletedUser.EmailChangeTokenNew) + require.Empty(ts.T(), deletedUser.EncryptedPassword) + require.Empty(ts.T(), deletedUser.PhoneChangeToken) + require.Empty(ts.T(), deletedUser.UserMetaData) + require.Empty(ts.T(), deletedUser.AppMetaData) + require.NotEmpty(ts.T(), deletedUser.DeletedAt) + require.NotEmpty(ts.T(), deletedUser.GetEmail()) + + // get soft-deleted user's identity from db + deletedIdentities, err := models.FindIdentitiesByUserID(ts.API.db, deletedUser.ID) + require.NoError(ts.T(), err) + + for _, identity := range deletedIdentities { + require.Empty(ts.T(), identity.IdentityData) + } +} + +func (ts *AdminTestSuite) TestAdminUserCreateWithDisabledLogin() { + var cases = []struct { + desc string + customConfig *conf.GlobalConfiguration + userData map[string]interface{} + expected int + }{ + { + desc: "Email Signups Disabled", + customConfig: &conf.GlobalConfiguration{ + JWT: ts.Config.JWT, + External: conf.ProviderConfiguration{ + Email: conf.EmailProviderConfiguration{ + Enabled: false, + }, + }, + }, + userData: map[string]interface{}{ + "email": "test1@example.com", + "password": "test1", + }, + expected: http.StatusOK, + }, + { + desc: "Phone Signups Disabled", + customConfig: &conf.GlobalConfiguration{ + JWT: ts.Config.JWT, + External: conf.ProviderConfiguration{ + Phone: conf.PhoneProviderConfiguration{ + Enabled: false, + }, + }, + }, + userData: map[string]interface{}{ + "phone": "123456789", + "password": "test1", + }, + expected: http.StatusOK, + }, + { + desc: "All Signups Disabled", + customConfig: &conf.GlobalConfiguration{ + JWT: ts.Config.JWT, + DisableSignup: true, + }, + userData: map[string]interface{}{ + "email": "test2@example.com", + "password": "test2", + }, + expected: http.StatusOK, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // Initialize user data + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.userData)) + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/admin/users", &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.Config.JWT = c.customConfig.JWT + ts.Config.External = c.customConfig.External + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expected, w.Code) + }) + } +} + +// TestAdminUserDeleteFactor tests API /admin/users//factors// +func (ts *AdminTestSuite) TestAdminUserDeleteFactor() { + u, err := models.NewUser("123456789", "test-delete@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + f := models.NewTOTPFactor(u, "testSimpleName") + require.NoError(ts.T(), f.UpdateStatus(ts.API.db, models.FactorStateVerified)) + require.NoError(ts.T(), f.SetSecret("secretkey", ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID, ts.Config.Security.DBEncryption.EncryptionKey)) + require.NoError(ts.T(), ts.API.db.Create(f), "Error saving new test factor") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/admin/users/%s/factors/%s/", u.ID, f.ID), nil) + + req.Header.Set("Authorization", 
fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + _, err = models.FindFactorByFactorID(ts.API.db, f.ID) + require.EqualError(ts.T(), err, models.FactorNotFoundError{}.Error()) + +} + +// TestAdminUserGetFactor tests API /admin/user//factors/ +func (ts *AdminTestSuite) TestAdminUserGetFactors() { + u, err := models.NewUser("123456789", "test-delete@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + f := models.NewTOTPFactor(u, "testSimpleName") + require.NoError(ts.T(), f.SetSecret("secretkey", ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID, ts.Config.Security.DBEncryption.EncryptionKey)) + require.NoError(ts.T(), ts.API.db.Create(f), "Error saving new test factor") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/admin/users/%s/factors/", u.ID), nil) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + getFactorsResp := []*models.Factor{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&getFactorsResp)) + require.Equal(ts.T(), getFactorsResp[0].Secret, "") +} + +func (ts *AdminTestSuite) TestAdminUserUpdateFactor() { + u, err := models.NewUser("123456789", "test-delete@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + f := models.NewPhoneFactor(u, "123456789", "testSimpleName") + require.NoError(ts.T(), f.SetSecret("secretkey", ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID, ts.Config.Security.DBEncryption.EncryptionKey)) + require.NoError(ts.T(), ts.API.db.Create(f), "Error saving new test factor") + + var cases = []struct { + Desc string + FactorData map[string]interface{} + ExpectedCode int + }{ + { + Desc: "Update Factor friendly name", + FactorData: map[string]interface{}{ + "friendly_name": "john", + }, + ExpectedCode: http.StatusOK, + }, + { + Desc: "Update Factor phone number", + FactorData: map[string]interface{}{ + "phone": "+1976154321", + }, + ExpectedCode: http.StatusOK, + }, + } + + // Initialize factor data + for _, c := range cases { + ts.Run(c.Desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.FactorData)) + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/admin/users/%s/factors/%s/", u.ID, f.ID), &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.ExpectedCode, w.Code) + }) + } +} + +func (ts *AdminTestSuite) TestAdminUserCreateValidationErrors() { + cases := []struct { + desc string + params map[string]interface{} + }{ + { + desc: "create user without email and phone", + params: map[string]interface{}{ + "password": "test_password", + }, + }, + { + desc: "create user with password and password hash", + params: map[string]interface{}{ + "email": "test@example.com", + "password": "test_password", + "password_hash": "$2y$10$Tk6yEdmTbb/eQ/haDMaCsuCsmtPVprjHMcij1RqiJdLGPDXnL3L1a", + }, + }, + { + desc: "invalid ban duration", + params: map[string]interface{}{ + "email": "test@example.com", + "ban_duration": "never", + }, + }, + { 
+ desc: "custom id is nil", + params: map[string]interface{}{ + "id": "00000000-0000-0000-0000-000000000000", + "email": "test@example.com", + }, + }, + { + desc: "bad id format", + params: map[string]interface{}{ + "id": "bad_uuid_format", + "email": "test@example.com", + }, + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) + req := httptest.NewRequest(http.MethodPost, "/admin/users", &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusBadRequest, w.Code, w) + + data := map[string]interface{}{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), data["error_code"], ErrorCodeValidationFailed) + }) + + } +} diff --git a/auth_v2.169.0/internal/api/anonymous.go b/auth_v2.169.0/internal/api/anonymous.go new file mode 100644 index 0000000..294f860 --- /dev/null +++ b/auth_v2.169.0/internal/api/anonymous.go @@ -0,0 +1,55 @@ +package api + +import ( + "net/http" + + "github.com/supabase/auth/internal/metering" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + config := a.config + db := a.db.WithContext(ctx) + aud := a.requestAud(ctx, r) + + if config.DisableSignup { + return unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance") + } + + params := &SignupParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + params.Aud = aud + params.Provider = "anonymous" + + newUser, err := params.ToUserModel(false /* <- isSSOUser */) + if err != nil { + return err + } + + var grantParams models.GrantParams + grantParams.FillGrantParams(r) + + var token *AccessTokenResponse + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + newUser, terr = a.signupNewUser(tx, newUser) + if terr != nil { + return terr + } + token, terr = a.issueRefreshToken(r, tx, newUser, models.Anonymous, grantParams) + if terr != nil { + return terr + } + return nil + }) + if err != nil { + return internalServerError("Database error creating anonymous user").WithInternalError(err) + } + + metering.RecordLogin("anonymous", newUser.ID) + return sendJSON(w, http.StatusOK, token) +} diff --git a/auth_v2.169.0/internal/api/anonymous_test.go b/auth_v2.169.0/internal/api/anonymous_test.go new file mode 100644 index 0000000..81d900d --- /dev/null +++ b/auth_v2.169.0/internal/api/anonymous_test.go @@ -0,0 +1,329 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gofrs/uuid" + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + mail "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/models" +) + +type AnonymousTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestAnonymous(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &AnonymousTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *AnonymousTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create 
anonymous user + params := &SignupParams{ + Aud: ts.Config.JWT.Aud, + Provider: "anonymous", + } + u, err := params.ToUserModel(false) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new anonymous test user") +} + +func (ts *AnonymousTestSuite) TestAnonymousLogins() { + ts.Config.External.AnonymousUsers.Enabled = true + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "data": map[string]interface{}{ + "field": "foo", + }, + })) + + req := httptest.NewRequest(http.MethodPost, "/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + assert.NotEmpty(ts.T(), data.User.ID) + assert.Equal(ts.T(), ts.Config.JWT.Aud, data.User.Aud) + assert.Empty(ts.T(), data.User.GetEmail()) + assert.Empty(ts.T(), data.User.GetPhone()) + assert.True(ts.T(), data.User.IsAnonymous) + assert.Equal(ts.T(), models.JSONMap(models.JSONMap{"field": "foo"}), data.User.UserMetaData) +} + +func (ts *AnonymousTestSuite) TestConvertAnonymousUserToPermanent() { + ts.Config.External.AnonymousUsers.Enabled = true + ts.Config.Sms.TestOTP = map[string]string{"1234567890": "000000", "1234560000": "000000"} + // test OTPs still require setting up an sms provider + ts.Config.Sms.Provider = "twilio" + ts.Config.Sms.Twilio.AccountSid = "fake-sid" + ts.Config.Sms.Twilio.AuthToken = "fake-token" + ts.Config.Sms.Twilio.MessageServiceSid = "fake-message-service-sid" + + cases := []struct { + desc string + body map[string]interface{} + verificationType string + }{ + { + desc: "convert anonymous user to permanent user with email", + body: map[string]interface{}{ + "email": "test@example.com", + }, + verificationType: "email_change", + }, + { + desc: "convert anonymous user to permanent user with phone", + body: map[string]interface{}{ + "phone": "1234567890", + }, + verificationType: "phone_change", + }, + { + desc: "convert anonymous user to permanent user with email & password", + body: map[string]interface{}{ + "email": "test2@example.com", + "password": "test-password", + }, + verificationType: "email_change", + }, + { + desc: "convert anonymous user to permanent user with phone & password", + body: map[string]interface{}{ + "phone": "1234560000", + "password": "test-password", + }, + verificationType: "phone_change", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{})) + + req := httptest.NewRequest(http.MethodPost, "/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + signupResponse := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&signupResponse)) + + // Add email to anonymous user + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + req = httptest.NewRequest(http.MethodPut, "/user", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", signupResponse.Token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, 
w.Code) + + // Check if anonymous user is still anonymous + user, err := models.FindUserByID(ts.API.db, signupResponse.User.ID) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), user) + require.True(ts.T(), user.IsAnonymous) + + // Check if user has a password set + if c.body["password"] != nil { + require.True(ts.T(), user.HasPassword()) + } + + switch c.verificationType { + case mail.EmailChangeVerification: + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "token_hash": user.EmailChangeTokenNew, + "type": c.verificationType, + })) + case phoneChangeVerification: + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "phone": user.PhoneChange, + "token": "000000", + "type": c.verificationType, + })) + } + + req = httptest.NewRequest(http.MethodPost, "/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + // User is a permanent user and not anonymous anymore + assert.Equal(ts.T(), signupResponse.User.ID, data.User.ID) + assert.Equal(ts.T(), ts.Config.JWT.Aud, data.User.Aud) + assert.False(ts.T(), data.User.IsAnonymous) + + // User should have an identity + assert.Len(ts.T(), data.User.Identities, 1) + + switch c.verificationType { + case mail.EmailChangeVerification: + assert.Equal(ts.T(), c.body["email"], data.User.GetEmail()) + assert.Equal(ts.T(), models.JSONMap(models.JSONMap{"provider": "email", "providers": []interface{}{"email"}}), data.User.AppMetaData) + assert.NotEmpty(ts.T(), data.User.EmailConfirmedAt) + case phoneChangeVerification: + assert.Equal(ts.T(), c.body["phone"], data.User.GetPhone()) + assert.Equal(ts.T(), models.JSONMap(models.JSONMap{"provider": "phone", "providers": []interface{}{"phone"}}), data.User.AppMetaData) + assert.NotEmpty(ts.T(), data.User.PhoneConfirmedAt) + } + }) + } +} + +func (ts *AnonymousTestSuite) TestRateLimitAnonymousSignups() { + var buffer bytes.Buffer + ts.Config.External.AnonymousUsers.Enabled = true + + // It rate limits after 30 requests + for i := 0; i < int(ts.Config.RateLimitAnonymousUsers); i++ { + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{})) + req := httptest.NewRequest(http.MethodPost, "http://localhost/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("My-Custom-Header", "1.2.3.4") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + } + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{})) + req := httptest.NewRequest(http.MethodPost, "http://localhost/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("My-Custom-Header", "1.2.3.4") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code) + + // It ignores X-Forwarded-For by default + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{})) + req.Header.Set("X-Forwarded-For", "1.1.1.1") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code) + + // It doesn't rate limit a new value for the limited header + req.Header.Set("My-Custom-Header", "5.6.7.8") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + 
assert.Equal(ts.T(), http.StatusBadRequest, w.Code) +} + +func (ts *AnonymousTestSuite) TestAdminUpdateAnonymousUser() { + claims := &AccessTokenClaims{ + Role: "supabase_admin", + } + adminJwt, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err) + + u1, err := models.NewUser("", "", "", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err) + u1.IsAnonymous = true + require.NoError(ts.T(), ts.API.db.Create(u1)) + + u2, err := models.NewUser("", "", "", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err) + u2.IsAnonymous = true + require.NoError(ts.T(), ts.API.db.Create(u2)) + + cases := []struct { + desc string + userId uuid.UUID + body map[string]interface{} + expected map[string]interface{} + expectedIdentities int + }{ + { + desc: "update anonymous user with email and email confirm true", + userId: u1.ID, + body: map[string]interface{}{ + "email": "foo@example.com", + "email_confirm": true, + }, + expected: map[string]interface{}{ + "email": "foo@example.com", + "is_anonymous": false, + }, + expectedIdentities: 1, + }, + { + desc: "update anonymous user with email and email confirm false", + userId: u2.ID, + body: map[string]interface{}{ + "email": "bar@example.com", + "email_confirm": false, + }, + expected: map[string]interface{}{ + "email": "bar@example.com", + "is_anonymous": true, + }, + expectedIdentities: 1, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + req := httptest.NewRequest(http.MethodPut, fmt.Sprintf("/admin/users/%s", c.userId), &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", adminJwt)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var data models.User + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.NotNil(ts.T(), data) + require.Len(ts.T(), data.Identities, c.expectedIdentities) + + actual := map[string]interface{}{ + "email": data.GetEmail(), + "is_anonymous": data.IsAnonymous, + } + + require.Equal(ts.T(), c.expected, actual) + }) + } +} diff --git a/auth_v2.169.0/internal/api/api.go b/auth_v2.169.0/internal/api/api.go new file mode 100644 index 0000000..aafcff2 --- /dev/null +++ b/auth_v2.169.0/internal/api/api.go @@ -0,0 +1,318 @@ +package api + +import ( + "net/http" + "regexp" + "time" + + "github.com/rs/cors" + "github.com/sebest/xff" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" + "github.com/supabase/hibp" +) + +const ( + audHeaderName = "X-JWT-AUD" + defaultVersion = "unknown version" +) + +var bearerRegexp = regexp.MustCompile(`^(?:B|b)earer (\S+$)`) + +// API is the main REST API +type API struct { + handler http.Handler + db *storage.Connection + config *conf.GlobalConfiguration + version string + + hibpClient *hibp.PwnedClient + + // overrideTime can be used to override the clock used by handlers. Should only be used in tests! 
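+	// When overrideTime is nil, Now() falls back to time.Now().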
+ overrideTime func() time.Time + + limiterOpts *LimiterOptions +} + +func (a *API) Now() time.Time { + if a.overrideTime != nil { + return a.overrideTime() + } + + return time.Now() +} + +// NewAPI instantiates a new REST API +func NewAPI(globalConfig *conf.GlobalConfiguration, db *storage.Connection, opt ...Option) *API { + return NewAPIWithVersion(globalConfig, db, defaultVersion, opt...) +} + +func (a *API) deprecationNotices() { + config := a.config + + log := logrus.WithField("component", "api") + + if config.JWT.AdminGroupName != "" { + log.Warn("DEPRECATION NOTICE: GOTRUE_JWT_ADMIN_GROUP_NAME not supported by Supabase's GoTrue, will be removed soon") + } + + if config.JWT.DefaultGroupName != "" { + log.Warn("DEPRECATION NOTICE: GOTRUE_JWT_DEFAULT_GROUP_NAME not supported by Supabase's GoTrue, will be removed soon") + } +} + +// NewAPIWithVersion creates a new REST API using the specified version +func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string, opt ...Option) *API { + api := &API{config: globalConfig, db: db, version: version} + + for _, o := range opt { + o.apply(api) + } + if api.limiterOpts == nil { + api.limiterOpts = NewLimiterOptions(globalConfig) + } + if api.config.Password.HIBP.Enabled { + httpClient := &http.Client{ + // all HIBP API requests should finish quickly to avoid + // unnecessary slowdowns + Timeout: 5 * time.Second, + } + + api.hibpClient = &hibp.PwnedClient{ + UserAgent: api.config.Password.HIBP.UserAgent, + HTTP: httpClient, + } + + if api.config.Password.HIBP.Bloom.Enabled { + cache := utilities.NewHIBPBloomCache(api.config.Password.HIBP.Bloom.Items, api.config.Password.HIBP.Bloom.FalsePositives) + api.hibpClient.Cache = cache + + logrus.Infof("Pwned passwords cache is %.2f KB", float64(cache.Cap())/(8*1024.0)) + } + } + + api.deprecationNotices() + + xffmw, _ := xff.Default() + logger := observability.NewStructuredLogger(logrus.StandardLogger(), globalConfig) + + r := newRouter() + r.UseBypass(observability.AddRequestID(globalConfig)) + r.UseBypass(logger) + r.UseBypass(xffmw.Handler) + r.UseBypass(recoverer) + + if globalConfig.API.MaxRequestDuration > 0 { + r.UseBypass(timeoutMiddleware(globalConfig.API.MaxRequestDuration)) + } + + // request tracing should be added only when tracing or metrics is enabled + if globalConfig.Tracing.Enabled || globalConfig.Metrics.Enabled { + r.UseBypass(observability.RequestTracing()) + } + + if globalConfig.DB.CleanupEnabled { + cleanup := models.NewCleanup(globalConfig) + r.UseBypass(api.databaseCleanup(cleanup)) + } + + r.Get("/health", api.HealthCheck) + r.Get("/.well-known/jwks.json", api.Jwks) + + r.Route("/callback", func(r *router) { + r.Use(api.isValidExternalHost) + r.Use(api.loadFlowState) + + r.Get("/", api.ExternalProviderCallback) + r.Post("/", api.ExternalProviderCallback) + }) + + r.Route("/", func(r *router) { + r.Use(api.isValidExternalHost) + + r.Get("/settings", api.Settings) + + r.Get("/authorize", api.ExternalProviderRedirect) + + r.With(api.requireAdminCredentials).Post("/invite", api.Invite) + r.With(api.verifyCaptcha).Route("/signup", func(r *router) { + // rate limit per hour + limitAnonymousSignIns := api.limiterOpts.AnonymousSignIns + limitSignups := api.limiterOpts.Signups + r.Post("/", func(w http.ResponseWriter, r *http.Request) error { + params := &SignupParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + if params.Email == "" && params.Phone == "" { + if !api.config.External.AnonymousUsers.Enabled { + 
return unprocessableEntityError(ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled") + } + if _, err := api.limitHandler(limitAnonymousSignIns)(w, r); err != nil { + return err + } + return api.SignupAnonymously(w, r) + } + + // apply ip-based rate limiting on otps + if _, err := api.limitHandler(limitSignups)(w, r); err != nil { + return err + } + return api.Signup(w, r) + }) + }) + r.With(api.limitHandler(api.limiterOpts.Recover)). + With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover) + + r.With(api.limitHandler(api.limiterOpts.Resend)). + With(api.verifyCaptcha).Post("/resend", api.Resend) + + r.With(api.limitHandler(api.limiterOpts.MagicLink)). + With(api.verifyCaptcha).Post("/magiclink", api.MagicLink) + + r.With(api.limitHandler(api.limiterOpts.Otp)). + With(api.verifyCaptcha).Post("/otp", api.Otp) + + r.With(api.limitHandler(api.limiterOpts.Token)). + With(api.verifyCaptcha).Post("/token", api.Token) + + r.With(api.limitHandler(api.limiterOpts.Verify)).Route("/verify", func(r *router) { + r.Get("/", api.Verify) + r.Post("/", api.Verify) + }) + + r.With(api.requireAuthentication).Post("/logout", api.Logout) + + r.With(api.requireAuthentication).Route("/reauthenticate", func(r *router) { + r.Get("/", api.Reauthenticate) + }) + + r.With(api.requireAuthentication).Route("/user", func(r *router) { + r.Get("/", api.UserGet) + r.With(api.limitHandler(api.limiterOpts.User)).Put("/", api.UserUpdate) + + r.Route("/identities", func(r *router) { + r.Use(api.requireManualLinkingEnabled) + r.Get("/authorize", api.LinkIdentity) + r.Delete("/{identity_id}", api.DeleteIdentity) + }) + }) + + r.With(api.requireAuthentication).Route("/factors", func(r *router) { + r.Use(api.requireNotAnonymous) + r.Post("/", api.EnrollFactor) + r.Route("/{factor_id}", func(r *router) { + r.Use(api.loadFactor) + + r.With(api.limitHandler(api.limiterOpts.FactorVerify)). + Post("/verify", api.VerifyFactor) + r.With(api.limitHandler(api.limiterOpts.FactorChallenge)). + Post("/challenge", api.ChallengeFactor) + r.Delete("/", api.UnenrollFactor) + + }) + }) + + r.Route("/sso", func(r *router) { + r.Use(api.requireSAMLEnabled) + r.With(api.limitHandler(api.limiterOpts.SSO)). + With(api.verifyCaptcha).Post("/", api.SingleSignOn) + + r.Route("/saml", func(r *router) { + r.Get("/metadata", api.SAMLMetadata) + + r.With(api.limitHandler(api.limiterOpts.SAMLAssertion)). 
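+					// the ACS endpoint goes through the dedicated SAMLAssertion limiter before SamlAcs runs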
+ Post("/acs", api.SamlAcs) + }) + }) + + r.Route("/admin", func(r *router) { + r.Use(api.requireAdminCredentials) + + r.Route("/audit", func(r *router) { + r.Get("/", api.adminAuditLog) + }) + + r.Route("/users", func(r *router) { + r.Get("/", api.adminUsers) + r.Post("/", api.adminUserCreate) + + r.Route("/{user_id}", func(r *router) { + r.Use(api.loadUser) + r.Route("/factors", func(r *router) { + r.Get("/", api.adminUserGetFactors) + r.Route("/{factor_id}", func(r *router) { + r.Use(api.loadFactor) + r.Delete("/", api.adminUserDeleteFactor) + r.Put("/", api.adminUserUpdateFactor) + }) + }) + + r.Get("/", api.adminUserGet) + r.Put("/", api.adminUserUpdate) + r.Delete("/", api.adminUserDelete) + }) + }) + + r.Post("/generate_link", api.adminGenerateLink) + + r.Route("/sso", func(r *router) { + r.Route("/providers", func(r *router) { + r.Get("/", api.adminSSOProvidersList) + r.Post("/", api.adminSSOProvidersCreate) + + r.Route("/{idp_id}", func(r *router) { + r.Use(api.loadSSOProvider) + + r.Get("/", api.adminSSOProvidersGet) + r.Put("/", api.adminSSOProvidersUpdate) + r.Delete("/", api.adminSSOProvidersDelete) + }) + }) + }) + + }) + }) + + corsHandler := cors.New(cors.Options{ + AllowedMethods: []string{http.MethodGet, http.MethodPost, http.MethodPut, http.MethodDelete}, + AllowedHeaders: globalConfig.CORS.AllAllowedHeaders([]string{"Accept", "Authorization", "Content-Type", "X-Client-IP", "X-Client-Info", audHeaderName, useCookieHeader, APIVersionHeaderName}), + ExposedHeaders: []string{"X-Total-Count", "Link", APIVersionHeaderName}, + AllowCredentials: true, + }) + + api.handler = corsHandler.Handler(r) + return api +} + +type HealthCheckResponse struct { + Version string `json:"version"` + Name string `json:"name"` + Description string `json:"description"` +} + +// HealthCheck endpoint indicates if the gotrue api service is available +func (a *API) HealthCheck(w http.ResponseWriter, r *http.Request) error { + return sendJSON(w, http.StatusOK, HealthCheckResponse{ + Version: a.version, + Name: "GoTrue", + Description: "GoTrue is a user registration and authentication API", + }) +} + +// Mailer returns NewMailer with the current tenant config +func (a *API) Mailer() mailer.Mailer { + config := a.config + return mailer.NewMailer(config) +} + +// ServeHTTP implements the http.Handler interface by passing the request along +// to its underlying Handler. +func (a *API) ServeHTTP(w http.ResponseWriter, r *http.Request) { + a.handler.ServeHTTP(w, r) +} diff --git a/auth_v2.169.0/internal/api/api_test.go b/auth_v2.169.0/internal/api/api_test.go new file mode 100644 index 0000000..a472be7 --- /dev/null +++ b/auth_v2.169.0/internal/api/api_test.go @@ -0,0 +1,57 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +const ( + apiTestVersion = "1" + apiTestConfig = "../../hack/test.env" +) + +func init() { + crypto.PasswordHashCost = crypto.QuickHashCost +} + +// setupAPIForTest creates a new API to run tests with. +// Using this function allows us to keep track of the database connection +// and cleaning up data between tests. 
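+// Callers own the returned connection; the test suites close it with
+// defer api.db.Close() once the suite has run.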
+func setupAPIForTest() (*API, *conf.GlobalConfiguration, error) { + return setupAPIForTestWithCallback(nil) +} + +func setupAPIForTestWithCallback(cb func(*conf.GlobalConfiguration, *storage.Connection)) (*API, *conf.GlobalConfiguration, error) { + config, err := conf.LoadGlobal(apiTestConfig) + if err != nil { + return nil, nil, err + } + + if cb != nil { + cb(config, nil) + } + + conn, err := test.SetupDBConnection(config) + if err != nil { + return nil, nil, err + } + + if cb != nil { + cb(nil, conn) + } + + limiterOpts := NewLimiterOptions(config) + return NewAPIWithVersion(config, conn, apiTestVersion, limiterOpts), config, nil +} + +func TestEmailEnabledByDefault(t *testing.T) { + api, _, err := setupAPIForTest() + require.NoError(t, err) + + require.True(t, api.config.External.Email.Enabled) +} diff --git a/auth_v2.169.0/internal/api/apiversions.go b/auth_v2.169.0/internal/api/apiversions.go new file mode 100644 index 0000000..b5394a5 --- /dev/null +++ b/auth_v2.169.0/internal/api/apiversions.go @@ -0,0 +1,35 @@ +package api + +import ( + "time" +) + +const APIVersionHeaderName = "X-Supabase-Api-Version" + +type APIVersion = time.Time + +var ( + APIVersionInitial = time.Time{} + APIVersion20240101 = time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC) +) + +func DetermineClosestAPIVersion(date string) (APIVersion, error) { + if date == "" { + return APIVersionInitial, nil + } + + parsed, err := time.ParseInLocation("2006-01-02", date, time.UTC) + if err != nil { + return APIVersionInitial, err + } + + if parsed.Compare(APIVersion20240101) >= 0 { + return APIVersion20240101, nil + } + + return APIVersionInitial, nil +} + +func FormatAPIVersion(apiVersion APIVersion) string { + return apiVersion.Format("2006-01-02") +} diff --git a/auth_v2.169.0/internal/api/apiversions_test.go b/auth_v2.169.0/internal/api/apiversions_test.go new file mode 100644 index 0000000..0a96221 --- /dev/null +++ b/auth_v2.169.0/internal/api/apiversions_test.go @@ -0,0 +1,29 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDetermineClosestAPIVersion(t *testing.T) { + version, err := DetermineClosestAPIVersion("") + require.NoError(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("Not a date") + require.Error(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("2023-12-31") + require.NoError(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("2024-01-01") + require.NoError(t, err) + require.Equal(t, APIVersion20240101, version) + + version, err = DetermineClosestAPIVersion("2024-01-02") + require.NoError(t, err) + require.Equal(t, APIVersion20240101, version) +} diff --git a/auth_v2.169.0/internal/api/audit.go b/auth_v2.169.0/internal/api/audit.go new file mode 100644 index 0000000..351a7d2 --- /dev/null +++ b/auth_v2.169.0/internal/api/audit.go @@ -0,0 +1,47 @@ +package api + +import ( + "net/http" + "strings" + + "github.com/supabase/auth/internal/models" +) + +var filterColumnMap = map[string][]string{ + "author": {"actor_username", "actor_name"}, + "action": {"action"}, + "type": {"log_type"}, +} + +func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + // aud := a.requestAud(ctx, r) + pageParams, err := paginate(r) + if err != nil { + return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err) + } + + var 
col []string + var qval string + q := r.URL.Query().Get("query") + if q != "" { + var exists bool + qparts := strings.SplitN(q, ":", 2) + col, exists = filterColumnMap[qparts[0]] + if !exists || len(qparts) < 2 { + return badRequestError(ErrorCodeValidationFailed, "Invalid query scope: %s", q) + } + qval = qparts[1] + } + + logs, err := models.FindAuditLogEntries(db, col, qval, pageParams) + if err != nil { + return internalServerError("Error searching for audit logs").WithInternalError(err) + } + + addPaginationHeaders(w, r, pageParams) + + return sendJSON(w, http.StatusOK, logs) +} diff --git a/auth_v2.169.0/internal/api/audit_test.go b/auth_v2.169.0/internal/api/audit_test.go new file mode 100644 index 0000000..c8e992e --- /dev/null +++ b/auth_v2.169.0/internal/api/audit_test.go @@ -0,0 +1,139 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type AuditTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration + + token string +} + +func TestAudit(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &AuditTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *AuditTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + ts.token = ts.makeSuperAdmin("") +} + +func (ts *AuditTestSuite) makeSuperAdmin(email string) string { + u, err := models.NewUser("", email, "test", ts.Config.JWT.Aud, map[string]interface{}{"full_name": "Test User"}) + require.NoError(ts.T(), err, "Error making new user") + + u.Role = "supabase_admin" + require.NoError(ts.T(), ts.API.db.Create(u)) + + session, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(session)) + + var token string + + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.PasswordGrant) + require.NoError(ts.T(), err, "Error generating access token") + + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.Parse(token, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + require.NoError(ts.T(), err, "Error parsing token") + + return token +} + +func (ts *AuditTestSuite) TestAuditGet() { + ts.prepareDeleteEvent() + // CHECK FOR AUDIT LOG + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/audit", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + assert.Equal(ts.T(), "; rel=\"last\"", w.Header().Get("Link")) + assert.Equal(ts.T(), "1", w.Header().Get("X-Total-Count")) + + logs := []models.AuditLogEntry{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&logs)) + + require.Len(ts.T(), logs, 1) + require.Contains(ts.T(), logs[0].Payload, "actor_username") + assert.Equal(ts.T(), "supabase_admin", logs[0].Payload["actor_username"]) + traits, ok := logs[0].Payload["traits"].(map[string]interface{}) + require.True(ts.T(), ok) + require.Contains(ts.T(), traits, "user_email") + assert.Equal(ts.T(), "test-delete@example.com", traits["user_email"]) +} + 
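+// TestAuditFilters exercises the optional "scope:value" filter accepted through
+// the "query" parameter of /admin/audit; the supported scopes are author, action
+// and type, which filterColumnMap in audit.go maps to audit-log columns. The
+// equivalent raw request looks roughly like this (values are only illustrative):
+//
+//	GET /admin/audit?query=action:user_deleted
+//	Authorization: Bearer <admin JWT>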
+func (ts *AuditTestSuite) TestAuditFilters() { + ts.prepareDeleteEvent() + + queries := []string{ + "/admin/audit?query=action:user_deleted", + "/admin/audit?query=type:team", + "/admin/audit?query=author:admin", + } + + for _, q := range queries { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, q, nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + logs := []models.AuditLogEntry{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&logs)) + + require.Len(ts.T(), logs, 1) + require.Contains(ts.T(), logs[0].Payload, "actor_username") + assert.Equal(ts.T(), "supabase_admin", logs[0].Payload["actor_username"]) + traits, ok := logs[0].Payload["traits"].(map[string]interface{}) + require.True(ts.T(), ok) + require.Contains(ts.T(), traits, "user_email") + assert.Equal(ts.T(), "test-delete@example.com", traits["user_email"]) + } +} + +func (ts *AuditTestSuite) prepareDeleteEvent() { + // DELETE USER + u, err := models.NewUser("12345678", "test-delete@example.com", "test", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user") + + // Setup request + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/admin/users/%s", u.ID), nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) +} diff --git a/auth_v2.169.0/internal/api/auth.go b/auth_v2.169.0/internal/api/auth.go new file mode 100644 index 0000000..b03767f --- /dev/null +++ b/auth_v2.169.0/internal/api/auth.go @@ -0,0 +1,141 @@ +package api + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/gofrs/uuid" + jwt "github.com/golang-jwt/jwt/v5" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// requireAuthentication checks incoming requests for tokens presented using the Authorization header +func (a *API) requireAuthentication(w http.ResponseWriter, r *http.Request) (context.Context, error) { + token, err := a.extractBearerToken(r) + if err != nil { + return nil, err + } + + ctx, err := a.parseJWTClaims(token, r) + if err != nil { + return ctx, err + } + + ctx, err = a.maybeLoadUserOrSession(ctx) + if err != nil { + return ctx, err + } + return ctx, err +} + +func (a *API) requireNotAnonymous(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx := r.Context() + claims := getClaims(ctx) + if claims.IsAnonymous { + return nil, forbiddenError(ErrorCodeNoAuthorization, "Anonymous user not allowed to perform these actions") + } + return ctx, nil +} + +func (a *API) requireAdmin(ctx context.Context) (context.Context, error) { + // Find the administrative user + claims := getClaims(ctx) + if claims == nil { + return nil, forbiddenError(ErrorCodeBadJWT, "Invalid token") + } + + adminRoles := a.config.JWT.AdminRoles + + if isStringInSlice(claims.Role, adminRoles) { + // successful authentication + return withAdminUser(ctx, &models.User{Role: claims.Role, Email: storage.NullString(claims.Role)}), nil + } + + return nil, forbiddenError(ErrorCodeNotAdmin, "User not allowed").WithInternalMessage(fmt.Sprintf("this token needs to have one of the following roles: %v", strings.Join(adminRoles, ", "))) +} + +func (a *API) extractBearerToken(r *http.Request) 
(string, error) { + authHeader := r.Header.Get("Authorization") + matches := bearerRegexp.FindStringSubmatch(authHeader) + if len(matches) != 2 { + return "", httpError(http.StatusUnauthorized, ErrorCodeNoAuthorization, "This endpoint requires a Bearer token") + } + + return matches[1], nil +} + +func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, error) { + ctx := r.Context() + config := a.config + + p := jwt.NewParser(jwt.WithValidMethods(config.JWT.ValidMethods)) + token, err := p.ParseWithClaims(bearer, &AccessTokenClaims{}, func(token *jwt.Token) (interface{}, error) { + if kid, ok := token.Header["kid"]; ok { + if kidStr, ok := kid.(string); ok { + return conf.FindPublicKeyByKid(kidStr, &config.JWT) + } + } + if alg, ok := token.Header["alg"]; ok { + if alg == jwt.SigningMethodHS256.Name { + // preserve backward compatibility for cases where the kid is not set + return []byte(config.JWT.Secret), nil + } + } + return nil, fmt.Errorf("missing kid") + }) + if err != nil { + return nil, forbiddenError(ErrorCodeBadJWT, "invalid JWT: unable to parse or verify signature, %v", err).WithInternalError(err) + } + + return withToken(ctx, token), nil +} + +func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, error) { + db := a.db.WithContext(ctx) + claims := getClaims(ctx) + + if claims == nil { + return ctx, forbiddenError(ErrorCodeBadJWT, "invalid token: missing claims") + } + + if claims.Subject == "" { + return nil, forbiddenError(ErrorCodeBadJWT, "invalid claim: missing sub claim") + } + + var user *models.User + if claims.Subject != "" { + userId, err := uuid.FromString(claims.Subject) + if err != nil { + return ctx, badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID").WithInternalError(err) + } + user, err = models.FindUserByID(db, userId) + if err != nil { + if models.IsNotFoundError(err) { + return ctx, forbiddenError(ErrorCodeUserNotFound, "User from sub claim in JWT does not exist") + } + return ctx, err + } + ctx = withUser(ctx, user) + } + + var session *models.Session + if claims.SessionId != "" && claims.SessionId != uuid.Nil.String() { + sessionId, err := uuid.FromString(claims.SessionId) + if err != nil { + return ctx, forbiddenError(ErrorCodeBadJWT, "invalid claim: session_id claim must be a UUID").WithInternalError(err) + } + session, err = models.FindSessionByID(db, sessionId, false) + if err != nil { + if models.IsNotFoundError(err) { + return ctx, forbiddenError(ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(err).WithInternalMessage(fmt.Sprintf("session id (%s) doesn't exist", sessionId)) + } + return ctx, err + } + ctx = withSession(ctx, session) + } + return ctx, nil +} diff --git a/auth_v2.169.0/internal/api/auth_test.go b/auth_v2.169.0/internal/api/auth_test.go new file mode 100644 index 0000000..71afe66 --- /dev/null +++ b/auth_v2.169.0/internal/api/auth_test.go @@ -0,0 +1,284 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gofrs/uuid" + jwt "github.com/golang-jwt/jwt/v5" + jwk "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type AuthTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestAuth(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := 
&AuthTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + suite.Run(t, ts) +} + +func (ts *AuthTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create user + u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") +} + +func (ts *AuthTestSuite) TestExtractBearerToken() { + userClaims := &AccessTokenClaims{ + Role: "authenticated", + } + userJwt, err := jwt.NewWithClaims(jwt.SigningMethodHS256, userClaims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err) + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set("Authorization", "Bearer "+userJwt) + + token, err := ts.API.extractBearerToken(req) + require.NoError(ts.T(), err) + require.Equal(ts.T(), userJwt, token) +} + +func (ts *AuthTestSuite) TestParseJWTClaims() { + cases := []struct { + desc string + key map[string]interface{} + }{ + { + desc: "HMAC key", + key: map[string]interface{}{ + "kty": "oct", + "k": "S1LgKUjeqXDEolv9WPtjUpADVMHU_KYu8uRDrM-pDGg", + "kid": "ac50c3cc-9cf7-4fd6-a11f-fe066fd39118", + "key_ops": []string{"sign", "verify"}, + "alg": "HS256", + }, + }, + { + desc: "RSA key", + key: map[string]interface{}{ + "kty": "RSA", + "n": "2g0B_hMIx5ZPuTUtLRpRr0k314XniYm3AUFgR5FmTZIjrn7vLwsWij-2egGZeHa-y9ypAgB9Q-lQ3AlT7RMPiCIyLQI6TTC8k10NEnj8c0QZwENx1Qr8aBbuZbOP9Cz30EMWZSbzMbz7r8-3rp5wBRBtIPnLlbfZh_p0iBaJfB77-r_mvhOIFM4xS7ef3nkE96dnvbEN5a-HfjzDJIAt-LniUvzMWW2gQcmHiM4oeijE3PHesapLMt2JpsMhSRo8L7tysags9VMoyZ1GnpCdjtRwb_KpY9QTjV6lL8G5nsKFH7bhABYcpjDOvqkfT5nPXj6C7oCo6MPRirPWUTbq2w", + "e": "AQAB", + "d": "OOTj_DNjOxCRRLYHT5lqbt4f3_BkdZKlWYKBaKsbkmnrPYCJUDEIdJIjPrpkHPZ-2hp9TrRp-upJ2t_kMhujFdY2WWAXbkSlL5475vICjODcBzqR3RC8wzwYgBjWGtQQ5RpcIZCELBovYbRFLR7SA8BBeTU0VaBe9gf3l_qpbOT9QIl268uFdWndTjpehGLQRmAtR1snhvTha0b9nsBZsM_K-EfnoF7Q_lPsjwWDvIGpFXao8Ifaa_sFtQkHjHVBMW2Qgx3ZSrEva_brk7w0MNSYI7Nsmr56xFOpFRwZy0v8ZtgQZ4hXmUInRHIoQ2APeds9YmemojvJKVflt9pLIQ", + "p": "-o2hdQ5Z35cIS5APTVULj_BMoPJpgkuX-PSYC1SeBeff9K04kG5zrFMWJy_-27-ys4q754lpNwJdX2CjN1nb6qyn-uKP8B2oLayKs9ebkiOqvm3S2Xblvi_F8x6sOLba3lTYHK8G7U9aMB9U0mhAzzMFdw15XXusVFDvk-zxL28", + "q": "3sp-7HzZE_elKRmebjivcDhkXO2GrcN3EIqYbbXssHZFXJwVE9oc2CErGWa7QetOCr9C--ZuTmX0X3L--CoYr-hMB0dN8lcAhapr3aau-4i7vE3DWSUdcFSyi0BBDg8pWQWbxNyTXBuWeh1cnRBsLjCxAOVTF0y3_BnVR7mbBVU", + "dp": "DuYHGMfOrk3zz1J0pnuNIXT_iX6AqZ_HHKWmuN3CO8Wq-oimWWhH9pJGOfRPqk9-19BDFiSEniHE3ZwIeI0eV5kGsBNyzatlybl90e3bMVhvmb08EXRRevqqQaesQ_8Tiq7u3t3Fgqz6RuxGBfDvEaMOCyNA-T8WYzkg1eH8AX8", + "dq": "opOCK3CvuDJvA57-TdBvtaRxGJ78OLD6oceBlA29useTthDwEJyJj-4kVVTyMRhUyuLnLoro06zytvRjuxR9D2CkmmseJkn2x5OlQwnvhv4wgSj99H9xDBfCcntg_bFyqtO859tObVh0ZogmnTbuuoYtpEm0aLxDRmRTjxOSXEE", + "qi": "8skVE7BDASHXytKSWYbkxD0B3WpXic2rtnLgiMgasdSxul8XwcB-vjVSZprVrxkcmm6ZhszoxOlq8yylBmMvAnG_gEzTls_xapeuEXGYiGaTcpkCt1r-tBKcQkka2SayaWwAljsX4xSw-zKP2koUkEET_tIcbBOW1R4OWfRGqOI", + "kid": "0d24b26c-b3ec-4c02-acfd-d5a54d50b3a4", + "key_ops": []string{"sign", "verify"}, + "alg": "RS256", + }, + }, + { + desc: "EC key", + key: map[string]interface{}{ + "kty": "EC", + "x": "5wsOh-DrNPpm9KkuydtgGs_cv3oNvtR9OdXywt12aS4", + "y": "0y01ZbuH_VQjMEd8fcYaLdiv25EVJ5GOrb79dJJsqrM", + "crv": "P-256", + "d": "EDP4ReMMpAUcf82EF3JYvkm8C5hVAh258Rj6f3HTx7c", + "kid": "10646a77-f470-44a8-8400-2f988d9c9c1a", + "key_ops": []string{"sign", "verify"}, + "alg": "ES256", + }, + }, + { + desc: "Ed25519 key", + key: map[string]interface{}{ + "crv": 
"Ed25519", + "d": "jVpCLvOxatVkKe1MW9nFRn6Q8VVZPq5yziKU_Z0Yu-c", + "x": "YDkGdufJBQEPO6ylvd9IKfZlzvm9tOG5VCDpkJSSkiA", + "kty": "OKP", + "kid": "ec5e7a96-ea66-456c-826c-d8d6cb928c0f", + "key_ops": []string{"sign", "verify"}, + "alg": "EdDSA", + }, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + bytes, err := json.Marshal(c.key) + require.NoError(ts.T(), err) + privKey, err := jwk.ParseKey(bytes) + require.NoError(ts.T(), err) + pubKey, err := privKey.PublicKey() + require.NoError(ts.T(), err) + ts.Config.JWT.Keys = conf.JwtKeysDecoder{privKey.KeyID(): conf.JwkInfo{ + PublicKey: pubKey, + PrivateKey: privKey, + }} + ts.Config.JWT.ValidMethods = nil + require.NoError(ts.T(), ts.Config.ApplyDefaults()) + + userClaims := &AccessTokenClaims{ + Role: "authenticated", + } + + // get signing key and method from config + jwk, err := conf.GetSigningJwk(&ts.Config.JWT) + require.NoError(ts.T(), err) + signingMethod := conf.GetSigningAlg(jwk) + signingKey, err := conf.GetSigningKey(jwk) + require.NoError(ts.T(), err) + + userJwtToken := jwt.NewWithClaims(signingMethod, userClaims) + require.NoError(ts.T(), err) + userJwtToken.Header["kid"] = jwk.KeyID() + userJwt, err := userJwtToken.SignedString(signingKey) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set("Authorization", "Bearer "+userJwt) + ctx, err := ts.API.parseJWTClaims(userJwt, req) + require.NoError(ts.T(), err) + + // check if token is stored in context + token := getToken(ctx) + require.Equal(ts.T(), userJwt, token.Raw) + }) + } +} + +func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + s, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(s)) + + require.NoError(ts.T(), ts.API.db.Load(s)) + + cases := []struct { + Desc string + UserJwtClaims *AccessTokenClaims + ExpectedError error + ExpectedUser *models.User + ExpectedSession *models.Session + }{ + { + Desc: "Missing Subject Claim", + UserJwtClaims: &AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: "", + }, + Role: "authenticated", + }, + ExpectedError: forbiddenError(ErrorCodeBadJWT, "invalid claim: missing sub claim"), + ExpectedUser: nil, + }, + { + Desc: "Valid Subject Claim", + UserJwtClaims: &AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: u.ID.String(), + }, + Role: "authenticated", + }, + ExpectedError: nil, + ExpectedUser: u, + }, + { + Desc: "Invalid Subject Claim", + UserJwtClaims: &AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: "invalid-subject-claim", + }, + Role: "authenticated", + }, + ExpectedError: badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID"), + ExpectedUser: nil, + }, + { + Desc: "Empty Session ID Claim", + UserJwtClaims: &AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: u.ID.String(), + }, + Role: "authenticated", + SessionId: "", + }, + ExpectedError: nil, + ExpectedUser: u, + }, + { + Desc: "Invalid Session ID Claim", + UserJwtClaims: &AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: u.ID.String(), + }, + Role: "authenticated", + SessionId: uuid.Nil.String(), + }, + ExpectedError: nil, + ExpectedUser: u, + }, + { + Desc: "Valid Session ID Claim", + UserJwtClaims: &AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: 
u.ID.String(), + }, + Role: "authenticated", + SessionId: s.ID.String(), + }, + ExpectedError: nil, + ExpectedUser: u, + ExpectedSession: s, + }, + { + Desc: "Session ID doesn't exist", + UserJwtClaims: &AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: u.ID.String(), + }, + Role: "authenticated", + SessionId: "73bf9ee0-9e8c-453b-b484-09cb93e2f341", + }, + ExpectedError: forbiddenError(ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(models.SessionNotFoundError{}).WithInternalMessage("session id (73bf9ee0-9e8c-453b-b484-09cb93e2f341) doesn't exist"), + ExpectedUser: u, + ExpectedSession: nil, + }, + } + + for _, c := range cases { + ts.Run(c.Desc, func() { + userJwt, err := jwt.NewWithClaims(jwt.SigningMethodHS256, c.UserJwtClaims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set("Authorization", "Bearer "+userJwt) + + ctx, err := ts.API.parseJWTClaims(userJwt, req) + require.NoError(ts.T(), err) + ctx, err = ts.API.maybeLoadUserOrSession(ctx) + if c.ExpectedError != nil { + require.Equal(ts.T(), c.ExpectedError.Error(), err.Error()) + } else { + require.Equal(ts.T(), c.ExpectedError, err) + } + require.Equal(ts.T(), c.ExpectedUser, getUser(ctx)) + require.Equal(ts.T(), c.ExpectedSession, getSession(ctx)) + }) + } +} diff --git a/auth_v2.169.0/internal/api/context.go b/auth_v2.169.0/internal/api/context.go new file mode 100644 index 0000000..3047f3d --- /dev/null +++ b/auth_v2.169.0/internal/api/context.go @@ -0,0 +1,243 @@ +package api + +import ( + "context" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/supabase/auth/internal/models" +) + +type contextKey string + +func (c contextKey) String() string { + return "gotrue api context key " + string(c) +} + +const ( + tokenKey = contextKey("jwt") + inviteTokenKey = contextKey("invite_token") + signatureKey = contextKey("signature") + externalProviderTypeKey = contextKey("external_provider_type") + userKey = contextKey("user") + targetUserKey = contextKey("target_user") + factorKey = contextKey("factor") + sessionKey = contextKey("session") + externalReferrerKey = contextKey("external_referrer") + functionHooksKey = contextKey("function_hooks") + adminUserKey = contextKey("admin_user") + oauthTokenKey = contextKey("oauth_token") // for OAuth1.0, also known as request token + oauthVerifierKey = contextKey("oauth_verifier") + ssoProviderKey = contextKey("sso_provider") + externalHostKey = contextKey("external_host") + flowStateKey = contextKey("flow_state_id") +) + +// withToken adds the JWT token to the context. +func withToken(ctx context.Context, token *jwt.Token) context.Context { + return context.WithValue(ctx, tokenKey, token) +} + +// getToken reads the JWT token from the context. +func getToken(ctx context.Context) *jwt.Token { + obj := ctx.Value(tokenKey) + if obj == nil { + return nil + } + + return obj.(*jwt.Token) +} + +func getClaims(ctx context.Context) *AccessTokenClaims { + token := getToken(ctx) + if token == nil { + return nil + } + return token.Claims.(*AccessTokenClaims) +} + +// withUser adds the user to the context. +func withUser(ctx context.Context, u *models.User) context.Context { + return context.WithValue(ctx, userKey, u) +} + +// withTargetUser adds the target user for linking to the context. 
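+// The target user is read back via getTargetUser in the external OAuth callback,
+// where a non-nil value switches the flow from creating a new account to linking
+// the incoming identity to that existing user.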
+func withTargetUser(ctx context.Context, u *models.User) context.Context { + return context.WithValue(ctx, targetUserKey, u) +} + +// with Factor adds the factor id to the context. +func withFactor(ctx context.Context, f *models.Factor) context.Context { + return context.WithValue(ctx, factorKey, f) +} + +// getUser reads the user from the context. +func getUser(ctx context.Context) *models.User { + if ctx == nil { + return nil + } + obj := ctx.Value(userKey) + if obj == nil { + return nil + } + return obj.(*models.User) +} + +// getTargetUser reads the user from the context. +func getTargetUser(ctx context.Context) *models.User { + if ctx == nil { + return nil + } + obj := ctx.Value(targetUserKey) + if obj == nil { + return nil + } + return obj.(*models.User) +} + +// getFactor reads the factor id from the context +func getFactor(ctx context.Context) *models.Factor { + obj := ctx.Value(factorKey) + if obj == nil { + return nil + } + return obj.(*models.Factor) +} + +// withSession adds the session to the context. +func withSession(ctx context.Context, s *models.Session) context.Context { + return context.WithValue(ctx, sessionKey, s) +} + +// getSession reads the session from the context. +func getSession(ctx context.Context) *models.Session { + if ctx == nil { + return nil + } + obj := ctx.Value(sessionKey) + if obj == nil { + return nil + } + return obj.(*models.Session) +} + +// withSignature adds the provided request ID to the context. +func withSignature(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, signatureKey, id) +} + +func withInviteToken(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, inviteTokenKey, token) +} + +func withFlowStateID(ctx context.Context, FlowStateID string) context.Context { + return context.WithValue(ctx, flowStateKey, FlowStateID) +} + +func getFlowStateID(ctx context.Context) string { + obj := ctx.Value(flowStateKey) + if obj == nil { + return "" + } + return obj.(string) +} + +func getInviteToken(ctx context.Context) string { + obj := ctx.Value(inviteTokenKey) + if obj == nil { + return "" + } + + return obj.(string) +} + +// withExternalProviderType adds the provided request ID to the context. +func withExternalProviderType(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, externalProviderTypeKey, id) +} + +// getExternalProviderType reads the request ID from the context. +func getExternalProviderType(ctx context.Context) string { + obj := ctx.Value(externalProviderTypeKey) + if obj == nil { + return "" + } + + return obj.(string) +} + +func withExternalReferrer(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, externalReferrerKey, token) +} + +func getExternalReferrer(ctx context.Context) string { + obj := ctx.Value(externalReferrerKey) + if obj == nil { + return "" + } + + return obj.(string) +} + +// withAdminUser adds the admin user to the context. +func withAdminUser(ctx context.Context, u *models.User) context.Context { + return context.WithValue(ctx, adminUserKey, u) +} + +// getAdminUser reads the admin user from the context. 
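+// It is populated by requireAdmin in auth.go once the JWT role matches one of
+// the configured JWT.AdminRoles.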
+func getAdminUser(ctx context.Context) *models.User { + obj := ctx.Value(adminUserKey) + if obj == nil { + return nil + } + return obj.(*models.User) +} + +// withRequestToken adds the request token to the context +func withRequestToken(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, oauthTokenKey, token) +} + +func getRequestToken(ctx context.Context) string { + obj := ctx.Value(oauthTokenKey) + if obj == nil { + return "" + } + return obj.(string) +} + +func withOAuthVerifier(ctx context.Context, token string) context.Context { + return context.WithValue(ctx, oauthVerifierKey, token) +} + +func getOAuthVerifier(ctx context.Context) string { + obj := ctx.Value(oauthVerifierKey) + if obj == nil { + return "" + } + return obj.(string) +} + +func withSSOProvider(ctx context.Context, provider *models.SSOProvider) context.Context { + return context.WithValue(ctx, ssoProviderKey, provider) +} + +func getSSOProvider(ctx context.Context) *models.SSOProvider { + obj := ctx.Value(ssoProviderKey) + if obj == nil { + return nil + } + return obj.(*models.SSOProvider) +} + +func withExternalHost(ctx context.Context, u *url.URL) context.Context { + return context.WithValue(ctx, externalHostKey, u) +} + +func getExternalHost(ctx context.Context) *url.URL { + obj := ctx.Value(externalHostKey) + if obj == nil { + return nil + } + return obj.(*url.URL) +} diff --git a/auth_v2.169.0/internal/api/errorcodes.go b/auth_v2.169.0/internal/api/errorcodes.go new file mode 100644 index 0000000..8f09901 --- /dev/null +++ b/auth_v2.169.0/internal/api/errorcodes.go @@ -0,0 +1,95 @@ +package api + +type ErrorCode = string + +const ( + // ErrorCodeUnknown should not be used directly, it only indicates a failure in the error handling system in such a way that an error code was not assigned properly. + ErrorCodeUnknown ErrorCode = "unknown" + + // ErrorCodeUnexpectedFailure signals an unexpected failure such as a 500 Internal Server Error. 
+ ErrorCodeUnexpectedFailure ErrorCode = "unexpected_failure" + + ErrorCodeValidationFailed ErrorCode = "validation_failed" + ErrorCodeBadJSON ErrorCode = "bad_json" + ErrorCodeEmailExists ErrorCode = "email_exists" + ErrorCodePhoneExists ErrorCode = "phone_exists" + ErrorCodeBadJWT ErrorCode = "bad_jwt" + ErrorCodeNotAdmin ErrorCode = "not_admin" + ErrorCodeNoAuthorization ErrorCode = "no_authorization" + ErrorCodeUserNotFound ErrorCode = "user_not_found" + ErrorCodeSessionNotFound ErrorCode = "session_not_found" + ErrorCodeSessionExpired ErrorCode = "session_expired" + ErrorCodeRefreshTokenNotFound ErrorCode = "refresh_token_not_found" + ErrorCodeRefreshTokenAlreadyUsed ErrorCode = "refresh_token_already_used" + ErrorCodeFlowStateNotFound ErrorCode = "flow_state_not_found" + ErrorCodeFlowStateExpired ErrorCode = "flow_state_expired" + ErrorCodeSignupDisabled ErrorCode = "signup_disabled" + ErrorCodeUserBanned ErrorCode = "user_banned" + ErrorCodeProviderEmailNeedsVerification ErrorCode = "provider_email_needs_verification" + ErrorCodeInviteNotFound ErrorCode = "invite_not_found" + ErrorCodeBadOAuthState ErrorCode = "bad_oauth_state" + ErrorCodeBadOAuthCallback ErrorCode = "bad_oauth_callback" + ErrorCodeOAuthProviderNotSupported ErrorCode = "oauth_provider_not_supported" + ErrorCodeUnexpectedAudience ErrorCode = "unexpected_audience" + ErrorCodeSingleIdentityNotDeletable ErrorCode = "single_identity_not_deletable" + ErrorCodeEmailConflictIdentityNotDeletable ErrorCode = "email_conflict_identity_not_deletable" + ErrorCodeIdentityAlreadyExists ErrorCode = "identity_already_exists" + ErrorCodeEmailProviderDisabled ErrorCode = "email_provider_disabled" + ErrorCodePhoneProviderDisabled ErrorCode = "phone_provider_disabled" + ErrorCodeTooManyEnrolledMFAFactors ErrorCode = "too_many_enrolled_mfa_factors" + ErrorCodeMFAFactorNameConflict ErrorCode = "mfa_factor_name_conflict" + ErrorCodeMFAFactorNotFound ErrorCode = "mfa_factor_not_found" + ErrorCodeMFAIPAddressMismatch ErrorCode = "mfa_ip_address_mismatch" + ErrorCodeMFAChallengeExpired ErrorCode = "mfa_challenge_expired" + ErrorCodeMFAVerificationFailed ErrorCode = "mfa_verification_failed" + ErrorCodeMFAVerificationRejected ErrorCode = "mfa_verification_rejected" + ErrorCodeInsufficientAAL ErrorCode = "insufficient_aal" + ErrorCodeCaptchaFailed ErrorCode = "captcha_failed" + ErrorCodeSAMLProviderDisabled ErrorCode = "saml_provider_disabled" + ErrorCodeManualLinkingDisabled ErrorCode = "manual_linking_disabled" + ErrorCodeSMSSendFailed ErrorCode = "sms_send_failed" + ErrorCodeEmailNotConfirmed ErrorCode = "email_not_confirmed" + ErrorCodePhoneNotConfirmed ErrorCode = "phone_not_confirmed" + ErrorCodeSAMLRelayStateNotFound ErrorCode = "saml_relay_state_not_found" + ErrorCodeSAMLRelayStateExpired ErrorCode = "saml_relay_state_expired" + ErrorCodeSAMLIdPNotFound ErrorCode = "saml_idp_not_found" + ErrorCodeSAMLAssertionNoUserID ErrorCode = "saml_assertion_no_user_id" + ErrorCodeSAMLAssertionNoEmail ErrorCode = "saml_assertion_no_email" + ErrorCodeUserAlreadyExists ErrorCode = "user_already_exists" + ErrorCodeSSOProviderNotFound ErrorCode = "sso_provider_not_found" + ErrorCodeSAMLMetadataFetchFailed ErrorCode = "saml_metadata_fetch_failed" + ErrorCodeSAMLIdPAlreadyExists ErrorCode = "saml_idp_already_exists" + ErrorCodeSSODomainAlreadyExists ErrorCode = "sso_domain_already_exists" + ErrorCodeSAMLEntityIDMismatch ErrorCode = "saml_entity_id_mismatch" + ErrorCodeConflict ErrorCode = "conflict" + ErrorCodeProviderDisabled ErrorCode = 
"provider_disabled" + ErrorCodeUserSSOManaged ErrorCode = "user_sso_managed" + ErrorCodeReauthenticationNeeded ErrorCode = "reauthentication_needed" + ErrorCodeSamePassword ErrorCode = "same_password" + ErrorCodeReauthenticationNotValid ErrorCode = "reauthentication_not_valid" + ErrorCodeOTPExpired ErrorCode = "otp_expired" + ErrorCodeOTPDisabled ErrorCode = "otp_disabled" + ErrorCodeIdentityNotFound ErrorCode = "identity_not_found" + ErrorCodeWeakPassword ErrorCode = "weak_password" + ErrorCodeOverRequestRateLimit ErrorCode = "over_request_rate_limit" + ErrorCodeOverEmailSendRateLimit ErrorCode = "over_email_send_rate_limit" + ErrorCodeOverSMSSendRateLimit ErrorCode = "over_sms_send_rate_limit" + ErrorCodeBadCodeVerifier ErrorCode = "bad_code_verifier" + ErrorCodeAnonymousProviderDisabled ErrorCode = "anonymous_provider_disabled" + ErrorCodeHookTimeout ErrorCode = "hook_timeout" + ErrorCodeHookTimeoutAfterRetry ErrorCode = "hook_timeout_after_retry" + ErrorCodeHookPayloadOverSizeLimit ErrorCode = "hook_payload_over_size_limit" + ErrorCodeHookPayloadInvalidContentType ErrorCode = "hook_payload_invalid_content_type" + ErrorCodeRequestTimeout ErrorCode = "request_timeout" + ErrorCodeMFAPhoneEnrollDisabled ErrorCode = "mfa_phone_enroll_not_enabled" + ErrorCodeMFAPhoneVerifyDisabled ErrorCode = "mfa_phone_verify_not_enabled" + ErrorCodeMFATOTPEnrollDisabled ErrorCode = "mfa_totp_enroll_not_enabled" + ErrorCodeMFATOTPVerifyDisabled ErrorCode = "mfa_totp_verify_not_enabled" + ErrorCodeMFAWebAuthnEnrollDisabled ErrorCode = "mfa_webauthn_enroll_not_enabled" + ErrorCodeMFAWebAuthnVerifyDisabled ErrorCode = "mfa_webauthn_verify_not_enabled" + ErrorCodeMFAVerifiedFactorExists ErrorCode = "mfa_verified_factor_exists" + //#nosec G101 -- Not a secret value. + ErrorCodeInvalidCredentials ErrorCode = "invalid_credentials" + ErrorCodeEmailAddressNotAuthorized ErrorCode = "email_address_not_authorized" + ErrorCodeEmailAddressInvalid ErrorCode = "email_address_invalid" +) diff --git a/auth_v2.169.0/internal/api/errors.go b/auth_v2.169.0/internal/api/errors.go new file mode 100644 index 0000000..0ce9512 --- /dev/null +++ b/auth_v2.169.0/internal/api/errors.go @@ -0,0 +1,330 @@ +package api + +import ( + "context" + "fmt" + "net/http" + "os" + "runtime/debug" + "time" + + "github.com/pkg/errors" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/utilities" +) + +// Common error messages during signup flow +var ( + DuplicateEmailMsg = "A user with this email address has already been registered" + DuplicatePhoneMsg = "A user with this phone number has already been registered" + UserExistsError error = errors.New("user already exists") +) + +const InvalidChannelError = "Invalid channel, supported values are 'sms' or 'whatsapp'. 'whatsapp' is only supported if Twilio or Twilio Verify is used as the provider." 
+ +var oauthErrorMap = map[int]string{ + http.StatusBadRequest: "invalid_request", + http.StatusUnauthorized: "unauthorized_client", + http.StatusForbidden: "access_denied", + http.StatusInternalServerError: "server_error", + http.StatusServiceUnavailable: "temporarily_unavailable", +} + +// OAuthError is the JSON handler for OAuth2 error responses +type OAuthError struct { + Err string `json:"error"` + Description string `json:"error_description,omitempty"` + InternalError error `json:"-"` + InternalMessage string `json:"-"` +} + +func (e *OAuthError) Error() string { + if e.InternalMessage != "" { + return e.InternalMessage + } + return fmt.Sprintf("%s: %s", e.Err, e.Description) +} + +// WithInternalError adds internal error information to the error +func (e *OAuthError) WithInternalError(err error) *OAuthError { + e.InternalError = err + return e +} + +// WithInternalMessage adds internal message information to the error +func (e *OAuthError) WithInternalMessage(fmtString string, args ...interface{}) *OAuthError { + e.InternalMessage = fmt.Sprintf(fmtString, args...) + return e +} + +// Cause returns the root cause error +func (e *OAuthError) Cause() error { + if e.InternalError != nil { + return e.InternalError + } + return e +} + +func oauthError(err string, description string) *OAuthError { + return &OAuthError{Err: err, Description: description} +} + +func badRequestError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusBadRequest, errorCode, fmtString, args...) +} + +func internalServerError(fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusInternalServerError, ErrorCodeUnexpectedFailure, fmtString, args...) +} + +func notFoundError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusNotFound, errorCode, fmtString, args...) +} + +func forbiddenError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusForbidden, errorCode, fmtString, args...) +} + +func unprocessableEntityError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusUnprocessableEntity, errorCode, fmtString, args...) +} + +func tooManyRequestsError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusTooManyRequests, errorCode, fmtString, args...) +} + +func conflictError(fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusConflict, ErrorCodeConflict, fmtString, args...) +} + +// HTTPError is an error with a message and an HTTP status code. +type HTTPError struct { + HTTPStatus int `json:"code"` // do not rename the JSON tags! + ErrorCode string `json:"error_code,omitempty"` // do not rename the JSON tags! + Message string `json:"msg"` // do not rename the JSON tags! 
+ InternalError error `json:"-"` + InternalMessage string `json:"-"` + ErrorID string `json:"error_id,omitempty"` +} + +func (e *HTTPError) Error() string { + if e.InternalMessage != "" { + return e.InternalMessage + } + return fmt.Sprintf("%d: %s", e.HTTPStatus, e.Message) +} + +func (e *HTTPError) Is(target error) bool { + return e.Error() == target.Error() +} + +// Cause returns the root cause error +func (e *HTTPError) Cause() error { + if e.InternalError != nil { + return e.InternalError + } + return e +} + +// WithInternalError adds internal error information to the error +func (e *HTTPError) WithInternalError(err error) *HTTPError { + e.InternalError = err + return e +} + +// WithInternalMessage adds internal message information to the error +func (e *HTTPError) WithInternalMessage(fmtString string, args ...interface{}) *HTTPError { + e.InternalMessage = fmt.Sprintf(fmtString, args...) + return e +} + +func httpError(httpStatus int, errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return &HTTPError{ + HTTPStatus: httpStatus, + ErrorCode: errorCode, + Message: fmt.Sprintf(fmtString, args...), + } +} + +// Recoverer is a middleware that recovers from panics, logs the panic (and a +// backtrace), and returns a HTTP 500 (Internal Server Error) status if +// possible. Recoverer prints a request ID if one is provided. +func recoverer(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + defer func() { + if rvr := recover(); rvr != nil { + logEntry := observability.GetLogEntry(r) + if logEntry != nil { + logEntry.Panic(rvr, debug.Stack()) + } else { + fmt.Fprintf(os.Stderr, "Panic: %+v\n", rvr) + debug.PrintStack() + } + + se := &HTTPError{ + HTTPStatus: http.StatusInternalServerError, + Message: http.StatusText(http.StatusInternalServerError), + } + HandleResponseError(se, w, r) + } + }() + next.ServeHTTP(w, r) + } + return http.HandlerFunc(fn) +} + +// ErrorCause is an error interface that contains the method Cause() for returning root cause errors +type ErrorCause interface { + Cause() error +} + +type HTTPErrorResponse20240101 struct { + Code ErrorCode `json:"code"` + Message string `json:"message"` +} + +func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { + log := observability.GetLogEntry(r).Entry + errorID := utilities.GetRequestID(r.Context()) + + apiVersion, averr := DetermineClosestAPIVersion(r.Header.Get(APIVersionHeaderName)) + if averr != nil { + log.WithError(averr).Warn("Invalid version passed to " + APIVersionHeaderName + " header, defaulting to initial version") + } else if apiVersion != APIVersionInitial { + // Echo back the determined API version from the request + w.Header().Set(APIVersionHeaderName, FormatAPIVersion(apiVersion)) + } + + switch e := err.(type) { + case *WeakPasswordError: + if apiVersion.Compare(APIVersion20240101) >= 0 { + var output struct { + HTTPErrorResponse20240101 + Payload struct { + Reasons []string `json:"reasons,omitempty"` + } `json:"weak_password,omitempty"` + } + + output.Code = ErrorCodeWeakPassword + output.Message = e.Message + output.Payload.Reasons = e.Reasons + + if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + + } else { + var output struct { + HTTPError + Payload struct { + Reasons []string `json:"reasons,omitempty"` + } `json:"weak_password,omitempty"` + } + + output.HTTPStatus = 
http.StatusUnprocessableEntity + output.ErrorCode = ErrorCodeWeakPassword + output.Message = e.Message + output.Payload.Reasons = e.Reasons + + w.Header().Set("x-sb-error-code", output.ErrorCode) + + if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + } + + case *HTTPError: + switch { + case e.HTTPStatus >= http.StatusInternalServerError: + e.ErrorID = errorID + // this will get us the stack trace too + log.WithError(e.Cause()).Error(e.Error()) + case e.HTTPStatus == http.StatusTooManyRequests: + log.WithError(e.Cause()).Warn(e.Error()) + default: + log.WithError(e.Cause()).Info(e.Error()) + } + + if e.ErrorCode != "" { + w.Header().Set("x-sb-error-code", e.ErrorCode) + } + + if apiVersion.Compare(APIVersion20240101) >= 0 { + resp := HTTPErrorResponse20240101{ + Code: e.ErrorCode, + Message: e.Message, + } + + if resp.Code == "" { + if e.HTTPStatus == http.StatusInternalServerError { + resp.Code = ErrorCodeUnexpectedFailure + } else { + resp.Code = ErrorCodeUnknown + } + } + + if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + } else { + if e.ErrorCode == "" { + if e.HTTPStatus == http.StatusInternalServerError { + e.ErrorCode = ErrorCodeUnexpectedFailure + } else { + e.ErrorCode = ErrorCodeUnknown + } + } + + // Provide better error messages for certain user-triggered Postgres errors. + if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil { + if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + return + } + + if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + } + + case *OAuthError: + log.WithError(e.Cause()).Info(e.Error()) + if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + + case ErrorCause: + HandleResponseError(e.Cause(), w, r) + + default: + log.WithError(e).Errorf("Unhandled server error: %s", e.Error()) + + if apiVersion.Compare(APIVersion20240101) >= 0 { + resp := HTTPErrorResponse20240101{ + Code: ErrorCodeUnexpectedFailure, + Message: "Unexpected failure, please check server logs for more information", + } + + if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + } else { + httpError := HTTPError{ + HTTPStatus: http.StatusInternalServerError, + ErrorCode: ErrorCodeUnexpectedFailure, + Message: "Unexpected failure, please check server logs for more information", + } + + if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil && jsonErr != context.DeadlineExceeded { + log.WithError(jsonErr).Warn("Failed to send JSON on ResponseWriter") + } + } + } +} + +func generateFrequencyLimitErrorMessage(timeStamp *time.Time, maxFrequency time.Duration) string { + now := time.Now() + left := timeStamp.Add(maxFrequency).Sub(now) / time.Second + return fmt.Sprintf("For security purposes, you can only request this after %d seconds.", left) +} diff --git 
a/auth_v2.169.0/internal/api/errors_test.go b/auth_v2.169.0/internal/api/errors_test.go new file mode 100644 index 0000000..5524672 --- /dev/null +++ b/auth_v2.169.0/internal/api/errors_test.go @@ -0,0 +1,105 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/observability" +) + +func TestHandleResponseErrorWithHTTPError(t *testing.T) { + examples := []struct { + HTTPError *HTTPError + APIVersion string + ExpectedBody string + }{ + { + HTTPError: badRequestError(ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "", + ExpectedBody: "{\"code\":400,\"error_code\":\"" + ErrorCodeBadJSON + "\",\"msg\":\"Unable to parse JSON\"}", + }, + { + HTTPError: badRequestError(ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "2023-12-31", + ExpectedBody: "{\"code\":400,\"error_code\":\"" + ErrorCodeBadJSON + "\",\"msg\":\"Unable to parse JSON\"}", + }, + { + HTTPError: badRequestError(ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "2024-01-01", + ExpectedBody: "{\"code\":\"" + ErrorCodeBadJSON + "\",\"message\":\"Unable to parse JSON\"}", + }, + { + HTTPError: &HTTPError{ + HTTPStatus: http.StatusBadRequest, + Message: "Uncoded failure", + }, + APIVersion: "2024-01-01", + ExpectedBody: "{\"code\":\"" + ErrorCodeUnknown + "\",\"message\":\"Uncoded failure\"}", + }, + { + HTTPError: &HTTPError{ + HTTPStatus: http.StatusInternalServerError, + Message: "Unexpected failure", + }, + APIVersion: "2024-01-01", + ExpectedBody: "{\"code\":\"" + ErrorCodeUnexpectedFailure + "\",\"message\":\"Unexpected failure\"}", + }, + } + + for _, example := range examples { + rec := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "http://example.com", nil) + require.NoError(t, err) + + if example.APIVersion != "" { + req.Header.Set(APIVersionHeaderName, example.APIVersion) + } + + HandleResponseError(example.HTTPError, rec, req) + + require.Equal(t, example.HTTPError.HTTPStatus, rec.Code) + require.Equal(t, example.ExpectedBody, rec.Body.String()) + } +} + +func TestRecoverer(t *testing.T) { + var logBuffer bytes.Buffer + config, err := conf.LoadGlobal(apiTestConfig) + require.NoError(t, err) + require.NoError(t, observability.ConfigureLogging(&config.Logging)) + + // logrus should write to the buffer so we can check if the logs are output correctly + logrus.SetOutput(&logBuffer) + panicHandler := recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("test panic") + })) + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "http://example.com", nil) + require.NoError(t, err) + + panicHandler.ServeHTTP(w, req) + + require.Equal(t, http.StatusInternalServerError, w.Code) + + var data HTTPError + + // panic should return an internal server error + require.NoError(t, json.NewDecoder(w.Body).Decode(&data)) + require.Equal(t, ErrorCodeUnexpectedFailure, data.ErrorCode) + require.Equal(t, http.StatusInternalServerError, data.HTTPStatus) + require.Equal(t, "Internal Server Error", data.Message) + + // panic should log the error message internally + var logs map[string]interface{} + require.NoError(t, json.NewDecoder(&logBuffer).Decode(&logs)) + require.Equal(t, "request panicked", logs["msg"]) + require.Equal(t, "test panic", logs["panic"]) + require.NotEmpty(t, logs["stack"]) +} diff --git a/auth_v2.169.0/internal/api/external.go 
b/auth_v2.169.0/internal/api/external.go new file mode 100644 index 0000000..768343d --- /dev/null +++ b/auth_v2.169.0/internal/api/external.go @@ -0,0 +1,684 @@ +package api + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/fatih/structs" + "github.com/gofrs/uuid" + jwt "github.com/golang-jwt/jwt/v5" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" + "golang.org/x/oauth2" +) + +// ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow +type ExternalProviderClaims struct { + AuthMicroserviceClaims + Provider string `json:"provider"` + InviteToken string `json:"invite_token,omitempty"` + Referrer string `json:"referrer,omitempty"` + FlowStateID string `json:"flow_state_id"` + LinkingTargetID string `json:"linking_target_id,omitempty"` +} + +// ExternalProviderRedirect redirects the request to the oauth provider +func (a *API) ExternalProviderRedirect(w http.ResponseWriter, r *http.Request) error { + rurl, err := a.GetExternalProviderRedirectURL(w, r, nil) + if err != nil { + return err + } + http.Redirect(w, r, rurl, http.StatusFound) + return nil +} + +// GetExternalProviderRedirectURL returns the URL to start the oauth flow with the corresponding oauth provider +func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Request, linkingTargetUser *models.User) (string, error) { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + + query := r.URL.Query() + providerType := query.Get("provider") + scopes := query.Get("scopes") + codeChallenge := query.Get("code_challenge") + codeChallengeMethod := query.Get("code_challenge_method") + + p, err := a.Provider(ctx, providerType, scopes) + if err != nil { + return "", badRequestError(ErrorCodeValidationFailed, "Unsupported provider: %+v", err).WithInternalError(err) + } + + inviteToken := query.Get("invite_token") + if inviteToken != "" { + _, userErr := models.FindUserByConfirmationToken(db, inviteToken) + if userErr != nil { + if models.IsNotFoundError(userErr) { + return "", notFoundError(ErrorCodeUserNotFound, "User identified by token not found") + } + return "", internalServerError("Database error finding user").WithInternalError(userErr) + } + } + + redirectURL := utilities.GetReferrer(r, config) + log := observability.GetLogEntry(r).Entry + log.WithField("provider", providerType).Info("Redirecting to external provider") + if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil { + return "", err + } + flowType := getFlowFromChallenge(codeChallenge) + + flowStateID := "" + if isPKCEFlow(flowType) { + flowState, err := generateFlowState(a.db, providerType, models.OAuth, codeChallengeMethod, codeChallenge, nil) + if err != nil { + return "", err + } + flowStateID = flowState.ID.String() + } + + claims := ExternalProviderClaims{ + AuthMicroserviceClaims: AuthMicroserviceClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)), + }, + SiteURL: config.SiteURL, + InstanceID: uuid.Nil.String(), + }, + Provider: providerType, + InviteToken: inviteToken, + Referrer: redirectURL, + FlowStateID: flowStateID, + } + + if linkingTargetUser != nil { + // this means that the user is 
performing manual linking + claims.LinkingTargetID = linkingTargetUser.ID.String() + } + + tokenString, err := signJwt(&config.JWT, claims) + if err != nil { + return "", internalServerError("Error creating state").WithInternalError(err) + } + + authUrlParams := make([]oauth2.AuthCodeOption, 0) + query.Del("scopes") + query.Del("provider") + query.Del("code_challenge") + query.Del("code_challenge_method") + for key := range query { + if key == "workos_provider" { + // See https://workos.com/docs/reference/sso/authorize/get + authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam("provider", query.Get(key))) + } else { + authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam(key, query.Get(key))) + } + } + + authURL := p.AuthCodeURL(tokenString, authUrlParams...) + + return authURL, nil +} + +// ExternalProviderCallback handles the callback endpoint in the external oauth provider flow +func (a *API) ExternalProviderCallback(w http.ResponseWriter, r *http.Request) error { + rurl := a.getExternalRedirectURL(r) + u, err := url.Parse(rurl) + if err != nil { + return err + } + redirectErrors(a.internalExternalProviderCallback, w, r, u) + return nil +} + +func (a *API) handleOAuthCallback(r *http.Request) (*OAuthProviderData, error) { + ctx := r.Context() + providerType := getExternalProviderType(ctx) + + var oAuthResponseData *OAuthProviderData + var err error + switch providerType { + case "twitter": + // future OAuth1.0 providers will use this method + oAuthResponseData, err = a.oAuth1Callback(ctx, providerType) + default: + oAuthResponseData, err = a.oAuthCallback(ctx, r, providerType) + } + if err != nil { + return nil, err + } + return oAuthResponseData, nil +} + +func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + var grantParams models.GrantParams + grantParams.FillGrantParams(r) + + providerType := getExternalProviderType(ctx) + data, err := a.handleOAuthCallback(r) + if err != nil { + return err + } + + userData := data.userData + if len(userData.Emails) <= 0 { + return internalServerError("Error getting user email from external provider") + } + userData.Metadata.EmailVerified = false + for _, email := range userData.Emails { + if email.Primary { + userData.Metadata.Email = email.Email + userData.Metadata.EmailVerified = email.Verified + break + } else { + userData.Metadata.Email = email.Email + userData.Metadata.EmailVerified = email.Verified + } + } + providerAccessToken := data.token + providerRefreshToken := data.refreshToken + + var flowState *models.FlowState + // if there's a non-empty FlowStateID we perform PKCE Flow + if flowStateID := getFlowStateID(ctx); flowStateID != "" { + flowState, err = models.FindFlowStateByID(a.db, flowStateID) + if models.IsNotFoundError(err) { + return unprocessableEntityError(ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err) + } else if err != nil { + return internalServerError("Failed to find flow state").WithInternalError(err) + } + + } + + var user *models.User + var token *AccessTokenResponse + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if targetUser := getTargetUser(ctx); targetUser != nil { + if user, terr = a.linkIdentityToUser(r, ctx, tx, userData, providerType); terr != nil { + return terr + } + } else if inviteToken := getInviteToken(ctx); inviteToken != "" { + if user, terr = a.processInvite(r, tx, userData, inviteToken, providerType); terr != nil { + return terr + } + } 
else { + if user, terr = a.createAccountFromExternalIdentity(tx, r, userData, providerType); terr != nil { + return terr + } + } + if flowState != nil { + // This means that the callback is using PKCE + flowState.ProviderAccessToken = providerAccessToken + flowState.ProviderRefreshToken = providerRefreshToken + flowState.UserID = &(user.ID) + issueTime := time.Now() + flowState.AuthCodeIssuedAt = &issueTime + + terr = tx.Update(flowState) + } else { + token, terr = a.issueRefreshToken(r, tx, user, models.OAuth, grantParams) + } + + if terr != nil { + return oauthError("server_error", terr.Error()) + } + return nil + }) + + if err != nil { + return err + } + + rurl := a.getExternalRedirectURL(r) + if flowState != nil { + // This means that the callback is using PKCE + // Set the flowState.AuthCode to the query param here + rurl, err = a.prepPKCERedirectURL(rurl, flowState.AuthCode) + if err != nil { + return err + } + } else if token != nil { + q := url.Values{} + q.Set("provider_token", providerAccessToken) + // Because not all providers give out a refresh token + // See corresponding OAuth2 spec: + if providerRefreshToken != "" { + q.Set("provider_refresh_token", providerRefreshToken) + } + + rurl = token.AsRedirectURL(rurl, q) + + } + + http.Redirect(w, r, rurl, http.StatusFound) + return nil +} + +func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http.Request, userData *provider.UserProvidedData, providerType string) (*models.User, error) { + ctx := r.Context() + aud := a.requestAud(ctx, r) + config := a.config + + var user *models.User + var identity *models.Identity + var identityData map[string]interface{} + if userData.Metadata != nil { + identityData = structs.Map(userData.Metadata) + } + + decision, terr := models.DetermineAccountLinking(tx, config, userData.Emails, aud, providerType, userData.Metadata.Subject) + if terr != nil { + return nil, terr + } + + switch decision.Decision { + case models.LinkAccount: + user = decision.User + + if identity, terr = a.createNewIdentity(tx, user, providerType, identityData); terr != nil { + return nil, terr + } + + if terr = user.UpdateUserMetaData(tx, identityData); terr != nil { + return nil, terr + } + + if terr = user.UpdateAppMetaDataProviders(tx); terr != nil { + return nil, terr + } + + case models.CreateAccount: + if config.DisableSignup { + return nil, unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance") + } + + params := &SignupParams{ + Provider: providerType, + Email: decision.CandidateEmail.Email, + Aud: aud, + Data: identityData, + } + + isSSOUser := false + if strings.HasPrefix(decision.LinkingDomain, "sso:") { + isSSOUser = true + } + + // because params above sets no password, this method is not + // computationally hard so it can be used within a database + // transaction + user, terr = params.ToUserModel(isSSOUser) + if terr != nil { + return nil, terr + } + + if user, terr = a.signupNewUser(tx, user); terr != nil { + return nil, terr + } + + if identity, terr = a.createNewIdentity(tx, user, providerType, identityData); terr != nil { + return nil, terr + } + user.Identities = append(user.Identities, *identity) + case models.AccountExists: + user = decision.User + identity = decision.Identities[0] + + identity.IdentityData = identityData + if terr = tx.UpdateOnly(identity, "identity_data", "last_sign_in_at"); terr != nil { + return nil, terr + } + if terr = user.UpdateUserMetaData(tx, identityData); terr != nil { + return nil, terr + } + if terr = 
user.UpdateAppMetaDataProviders(tx); terr != nil { + return nil, terr + } + + case models.MultipleAccounts: + return nil, internalServerError("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain) + + default: + return nil, internalServerError("Unknown automatic linking decision: %v", decision.Decision) + } + + if user.IsBanned() { + return nil, forbiddenError(ErrorCodeUserBanned, "User is banned") + } + + if !user.IsConfirmed() { + // The user may have other unconfirmed email + password + // combination, phone or oauth identities. These identities + // need to be removed when a new oauth identity is being added + // to prevent pre-account takeover attacks from happening. + if terr = user.RemoveUnconfirmedIdentities(tx, identity); terr != nil { + return nil, internalServerError("Error updating user").WithInternalError(terr) + } + if decision.CandidateEmail.Verified || config.Mailer.Autoconfirm { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserSignedUpAction, "", map[string]interface{}{ + "provider": providerType, + }); terr != nil { + return nil, terr + } + // fall through to auto-confirm and issue token + if terr = user.Confirm(tx); terr != nil { + return nil, internalServerError("Error updating user").WithInternalError(terr) + } + } else { + emailConfirmationSent := false + if decision.CandidateEmail.Email != "" { + if terr = a.sendConfirmation(r, tx, user, models.ImplicitFlow); terr != nil { + return nil, terr + } + emailConfirmationSent = true + } + if !config.Mailer.AllowUnverifiedEmailSignIns { + if emailConfirmationSent { + return nil, storage.NewCommitWithError(unprocessableEntityError(ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) + } + return nil, storage.NewCommitWithError(unprocessableEntityError(ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. 
Verify the email with %v in order to sign in", providerType, providerType))) + } + } + } else { + if terr := models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", map[string]interface{}{ + "provider": providerType, + }); terr != nil { + return nil, terr + } + } + + return user, nil +} + +func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *provider.UserProvidedData, inviteToken, providerType string) (*models.User, error) { + user, err := models.FindUserByConfirmationToken(tx, inviteToken) + if err != nil { + if models.IsNotFoundError(err) { + return nil, notFoundError(ErrorCodeInviteNotFound, "Invite not found") + } + return nil, internalServerError("Database error finding user").WithInternalError(err) + } + + var emailData *provider.Email + var emails []string + for i, e := range userData.Emails { + emails = append(emails, e.Email) + if user.GetEmail() == e.Email { + emailData = &userData.Emails[i] + break + } + } + + if emailData == nil { + return nil, badRequestError(ErrorCodeValidationFailed, "Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", ")) + } + + var identityData map[string]interface{} + if userData.Metadata != nil { + identityData = structs.Map(userData.Metadata) + } + identity, err := a.createNewIdentity(tx, user, providerType, identityData) + if err != nil { + return nil, err + } + if err := user.UpdateAppMetaData(tx, map[string]interface{}{ + "provider": providerType, + }); err != nil { + return nil, err + } + if err := user.UpdateAppMetaDataProviders(tx); err != nil { + return nil, err + } + if err := user.UpdateUserMetaData(tx, identityData); err != nil { + return nil, internalServerError("Database error updating user").WithInternalError(err) + } + + if err := models.NewAuditLogEntry(r, tx, user, models.InviteAcceptedAction, "", map[string]interface{}{ + "provider": providerType, + }); err != nil { + return nil, err + } + + // an account with a previously unconfirmed email + password + // combination or phone may exist. so now that there is an + // OAuth identity bound to this user, and since they have not + // confirmed their email or phone, they are unaware that a + // potentially malicious door exists into their account; thus + // the password and phone needs to be removed. 
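+	// The same cleanup is performed in createAccountFromExternalIdentity above for
+	// users that are still unconfirmed when a new OAuth identity is added.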
+	if err := user.RemoveUnconfirmedIdentities(tx, identity); err != nil {
+		return nil, internalServerError("Error updating user").WithInternalError(err)
+	}
+
+	// confirm because they were able to respond to invite email
+	if err := user.Confirm(tx); err != nil {
+		return nil, err
+	}
+	return user, nil
+}
+
+func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.Context, error) {
+	var state string
+	switch r.Method {
+	case http.MethodPost:
+		state = r.FormValue("state")
+	default:
+		state = r.URL.Query().Get("state")
+	}
+	if state == "" {
+		return ctx, badRequestError(ErrorCodeBadOAuthCallback, "OAuth state parameter missing")
+	}
+	config := a.config
+	claims := ExternalProviderClaims{}
+	p := jwt.NewParser(jwt.WithValidMethods(config.JWT.ValidMethods))
+	_, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) {
+		if kid, ok := token.Header["kid"]; ok {
+			if kidStr, ok := kid.(string); ok {
+				return conf.FindPublicKeyByKid(kidStr, &config.JWT)
+			}
+		}
+		if alg, ok := token.Header["alg"]; ok {
+			if alg == jwt.SigningMethodHS256.Name {
+				// preserve backward compatibility for cases where the kid is not set
+				return []byte(config.JWT.Secret), nil
+			}
+		}
+		return nil, fmt.Errorf("missing kid")
+	})
+	if err != nil {
+		return ctx, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err)
+	}
+	if claims.Provider == "" {
+		return ctx, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)")
+	}
+	if claims.InviteToken != "" {
+		ctx = withInviteToken(ctx, claims.InviteToken)
+	}
+	if claims.Referrer != "" {
+		ctx = withExternalReferrer(ctx, claims.Referrer)
+	}
+	if claims.FlowStateID != "" {
+		ctx = withFlowStateID(ctx, claims.FlowStateID)
+	}
+	if claims.LinkingTargetID != "" {
+		linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID)
+		if err != nil {
+			return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)")
+		}
+		u, err := models.FindUserByID(a.db, linkingTargetUserID)
+		if err != nil {
+			if models.IsNotFoundError(err) {
+				return nil, unprocessableEntityError(ErrorCodeUserNotFound, "Linking target user not found")
+			}
+			return nil, internalServerError("Database error loading user").WithInternalError(err)
+		}
+		ctx = withTargetUser(ctx, u)
+	}
+	ctx = withExternalProviderType(ctx, claims.Provider)
+	return withSignature(ctx, state), nil
+}
+
+// Provider returns a Provider interface for the given name.
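+// Names are matched case-insensitively against the configured external providers;
+// an unknown name returns an error rather than a nil Provider. A typical call looks like:
+//
+//	p, err := a.Provider(ctx, "github", "")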
+func (a *API) Provider(ctx context.Context, name string, scopes string) (provider.Provider, error) { + config := a.config + name = strings.ToLower(name) + + switch name { + case "apple": + return provider.NewAppleProvider(ctx, config.External.Apple) + case "azure": + return provider.NewAzureProvider(config.External.Azure, scopes) + case "bitbucket": + return provider.NewBitbucketProvider(config.External.Bitbucket) + case "discord": + return provider.NewDiscordProvider(config.External.Discord, scopes) + case "facebook": + return provider.NewFacebookProvider(config.External.Facebook, scopes) + case "figma": + return provider.NewFigmaProvider(config.External.Figma, scopes) + case "fly": + return provider.NewFlyProvider(config.External.Fly, scopes) + case "github": + return provider.NewGithubProvider(config.External.Github, scopes) + case "gitlab": + return provider.NewGitlabProvider(config.External.Gitlab, scopes) + case "google": + return provider.NewGoogleProvider(ctx, config.External.Google, scopes) + case "kakao": + return provider.NewKakaoProvider(config.External.Kakao, scopes) + case "keycloak": + return provider.NewKeycloakProvider(config.External.Keycloak, scopes) + case "linkedin": + return provider.NewLinkedinProvider(config.External.Linkedin, scopes) + case "linkedin_oidc": + return provider.NewLinkedinOIDCProvider(config.External.LinkedinOIDC, scopes) + case "notion": + return provider.NewNotionProvider(config.External.Notion) + case "spotify": + return provider.NewSpotifyProvider(config.External.Spotify, scopes) + case "slack": + return provider.NewSlackProvider(config.External.Slack, scopes) + case "slack_oidc": + return provider.NewSlackOIDCProvider(config.External.SlackOIDC, scopes) + case "twitch": + return provider.NewTwitchProvider(config.External.Twitch, scopes) + case "twitter": + return provider.NewTwitterProvider(config.External.Twitter, scopes) + case "vercel_marketplace": + return provider.NewVercelMarketplaceProvider(config.External.VercelMarketplace, scopes) + case "workos": + return provider.NewWorkOSProvider(config.External.WorkOS) + case "zoom": + return provider.NewZoomProvider(config.External.Zoom) + default: + return nil, fmt.Errorf("Provider %s could not be found", name) + } +} + +func redirectErrors(handler apiHandler, w http.ResponseWriter, r *http.Request, u *url.URL) { + ctx := r.Context() + log := observability.GetLogEntry(r).Entry + errorID := utilities.GetRequestID(ctx) + err := handler(w, r) + if err != nil { + q := getErrorQueryString(err, errorID, log, u.Query()) + u.RawQuery = q.Encode() + + // TODO: deprecate returning error details in the query fragment + hq := url.Values{} + if q.Get("error") != "" { + hq.Set("error", q.Get("error")) + } + if q.Get("error_description") != "" { + hq.Set("error_description", q.Get("error_description")) + } + if q.Get("error_code") != "" { + hq.Set("error_code", q.Get("error_code")) + } + u.Fragment = hq.Encode() + http.Redirect(w, r, u.String(), http.StatusFound) + } +} + +func getErrorQueryString(err error, errorID string, log logrus.FieldLogger, q url.Values) *url.Values { + switch e := err.(type) { + case *HTTPError: + if e.ErrorCode == ErrorCodeSignupDisabled { + q.Set("error", "access_denied") + } else if e.ErrorCode == ErrorCodeUserBanned { + q.Set("error", "access_denied") + } else if e.ErrorCode == ErrorCodeProviderEmailNeedsVerification { + q.Set("error", "access_denied") + } else if str, ok := oauthErrorMap[e.HTTPStatus]; ok { + q.Set("error", str) + } else { + q.Set("error", "server_error") + } + if 
e.HTTPStatus >= http.StatusInternalServerError { + e.ErrorID = errorID + // this will get us the stack trace too + log.WithError(e.Cause()).Error(e.Error()) + } else { + log.WithError(e.Cause()).Info(e.Error()) + } + q.Set("error_description", e.Message) + q.Set("error_code", e.ErrorCode) + case *OAuthError: + q.Set("error", e.Err) + q.Set("error_description", e.Description) + log.WithError(e.Cause()).Info(e.Error()) + case ErrorCause: + return getErrorQueryString(e.Cause(), errorID, log, q) + default: + error_type, error_description := "server_error", err.Error() + + // Provide better error messages for certain user-triggered Postgres errors. + if pgErr := utilities.NewPostgresError(e); pgErr != nil { + error_description = pgErr.Message + if oauthErrorType, ok := oauthErrorMap[pgErr.HttpStatusCode]; ok { + error_type = oauthErrorType + } + } + + q.Set("error", error_type) + q.Set("error_description", error_description) + } + return &q +} + +func (a *API) getExternalRedirectURL(r *http.Request) string { + ctx := r.Context() + config := a.config + if config.External.RedirectURL != "" { + return config.External.RedirectURL + } + if er := getExternalReferrer(ctx); er != "" { + return er + } + return config.SiteURL +} + +func (a *API) createNewIdentity(tx *storage.Connection, user *models.User, providerType string, identityData map[string]interface{}) (*models.Identity, error) { + identity, err := models.NewIdentity(user, providerType, identityData) + if err != nil { + return nil, err + } + + if terr := tx.Create(identity); terr != nil { + return nil, internalServerError("Error creating identity").WithInternalError(terr) + } + + return identity, nil +} diff --git a/auth_v2.169.0/internal/api/external_apple_test.go b/auth_v2.169.0/internal/api/external_apple_test.go new file mode 100644 index 0000000..5a0b497 --- /dev/null +++ b/auth_v2.169.0/internal/api/external_apple_test.go @@ -0,0 +1,33 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +func (ts *ExternalTestSuite) TestSignupExternalApple() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=apple", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Apple.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Apple.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("email name", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("apple", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} diff --git a/auth_v2.169.0/internal/api/external_azure_test.go b/auth_v2.169.0/internal/api/external_azure_test.go new file mode 100644 index 0000000..aac124c --- /dev/null +++ b/auth_v2.169.0/internal/api/external_azure_test.go @@ -0,0 +1,269 @@ +package api + +import ( + "context" + "crypto" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + jwt 
"github.com/golang-jwt/jwt/v5" + "github.com/supabase/auth/internal/api/provider" +) + +const ( + azureUser string = `{"name":"Azure Test","email":"azure@example.com","sub":"azuretestid"}` + azureUserNoEmail string = `{"name":"Azure Test","sub":"azuretestid"}` +) + +func idTokenPrivateKey() *rsa.PrivateKey { + // #nosec + der, err := base64.StdEncoding.DecodeString("MIIEpAIBAAKCAQEAvklrFDsVgbhs3DOQICMqm4xdFoi/MHj/T6XH8S7wXWd0roqdWVarwCLV4y3DILkLre4PzNK+hEY5NAnoAKrsCMyyCb4Wdl8HCdJk4ojDqAig+DJw67imqZoxJMFJyIhfMJhwVK1V8GRUPATn855rygLo7wThahMJeEHNiJr3TtV6Rf35KSs7DuyoWIUSjISYabQozKqIvpdUpTpSqjlOQvjdAxggRyycBZSgLzjWhsA8metnAMO48bX4bgiHLR6Kzu/dfPyEVPfgeYpA2ebIY6GzIUxVS0yX8+ExA6jeLCkuepjLHuz5XCJtd6zzGDXr1eX7nA6ZIeUNdFbWRDnPawIDAQABAoIBABH4Qvl1HvHSJc2hvPGcAJER71SKc2uzcYDnCfu30BEyDO3Sv0tJiQyq/YHnt26mqviw66MPH9jD/PDyIou1mHa4RfPvlJV3IeYGjWprOfbrYbAuq0VHec24dv2el0YtwreHHcyRVfVOtDm6yODTzCAWqEKyNktbIuDNbgiBgetayaJecDRoFMF9TOCeMCL92iZytzAr7fi+JWtLkRS/GZRIBjbr8LJ/ueYoCRmIx3MIw0WdPp7v2ZfeRTxP7LxJZ+MAsrq2pstmZYP7K0305e0bCJX1HexfXLs2Ul7u8zaxrXL8zw4/9+/GMsAeU3ffCVnGz/RKL5+T6iuz2RotjFECgYEA+Xk7DGwRXfDg9xba1GVFGeiC4nybqZw/RfZKcz/RRJWSHRJV/ps1avtbca3B19rjI6rewZMO1NWNv/tI2BdXP8vAKUnI9OHJZ+J/eZzmqDE6qu0v0ddRFUDzCMWE0j8BjrUdy44n4NQgopcv14u0iyr9tuhGO6YXn2SuuvEkZokCgYEAw0PNnT55kpkEhXSp7An2hdBJEub9ST7hS6Kcd8let62/qUZ/t5jWigSkWC1A2bMtH55+LgudIFjiehwVzRs7jym2j4jkKZGonyAX1l9IWgXwKl7Pn49lEQH5Yk6MhnXdyLGoFTzXiUyk/fKvgXX7jow1bD3j6sAc8P495I7TyVMCgYAHg6VJrH+har37805IE3zPWPeIRuSRaUlmnBKGAigVfsPV6FV6w8YKIOQSOn+aNtecnWr0Pa+2rXAFllYNXDaej06Mb9KDvcFJRcM9MIKqEkGIIHjOQ0QH9drcKsbjZk5vs/jfxrpgxULuYstoHKclgff+aGSlK02O2YOB0f2csQKBgQCEC/MdNiWCpKXxFg7fB3HF1i/Eb56zjKlQu7uyKeQ6tG3bLEisQNg8Z5034Apt7gRC0KyluMbeHB2z1BBOLu9dBill8X3SOqVcTpiwKKlF76QVEx622YLQOJSMDXBscYK0+KchDY74U3N0JEzZcI7YPCrYcxYRJy+rLVNvn8LK7wKBgQDE8THsZ589e10F0zDBvPK56o8PJnPeH71sgdM2Co4oLzBJ6g0rpJOKfcc03fLHsoJVOAya9WZeIy6K8+WVdcPTadR07S4p8/tcK1eguu5qlmCUOzswrTKAaJoIHO7cddQp3nySIqgYtkGdHKuvlQDMQkEKJS0meOm+vdeAG2rkaA==") + if err != nil { + panic(err) + } + + privateKey, err := x509.ParsePKCS1PrivateKey(der) + if err != nil { + panic(err) + } + + privateKey.E = 65537 + + return privateKey +} + +func setupAzureOverrideVerifiers() { + provider.OverrideVerifiers["https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/oauth2/v2.0/authorize"] = func(ctx context.Context, config *oidc.Config) *oidc.IDTokenVerifier { + pk := idTokenPrivateKey() + + return oidc.NewVerifier( + provider.IssuerAzureMicrosoft, + &oidc.StaticKeySet{ + PublicKeys: []crypto.PublicKey{ + &pk.PublicKey, + }, + }, + config, + ) + } +} + +func mintIDToken(user string) string { + var idToken struct { + Issuer string `json:"iss"` + IssuedAt int `json:"iat"` + ExpiresAt int `json:"exp"` + Audience string `json:"aud"` + + Sub string `json:"sub,omitempty"` + Name string `json:"name,omitempty"` + Email string `json:"email,omitempty"` + XmsEdov any `json:"xms_edov,omitempty"` + } + + if err := json.Unmarshal([]byte(user), &idToken); err != nil { + panic(err) + } + + now := time.Now() + + idToken.Issuer = provider.IssuerAzureMicrosoft + idToken.IssuedAt = int(now.Unix()) + idToken.ExpiresAt = int(now.Unix() + 60*60) + idToken.Audience = "testclientid" + + header := base64.RawURLEncoding.EncodeToString([]byte(`{"typ":"JWT","alg":"RS256"}`)) + + data, err := json.Marshal(idToken) + if err != nil { + panic(err) + } + + payload := base64.RawURLEncoding.EncodeToString(data) + sum := sha256.Sum256([]byte(header + "." 
+ payload)) + + pk := idTokenPrivateKey() + sig, err := rsa.SignPKCS1v15(nil, pk, crypto.SHA256, sum[:]) + if err != nil { + panic(err) + } + + token := header + "." + payload + "." + base64.RawURLEncoding.EncodeToString(sig) + + return token +} + +func (ts *ExternalTestSuite) TestSignupExternalAzure() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=azure", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Azure.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Azure.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("openid", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("azure", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func AzureTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, code string, user string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth2/v2.0/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Azure.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprintf(w, `{"access_token":"azure_token","expires_in":100000,"id_token":%q}`, mintIDToken(user)) + default: + w.WriteHeader(500) + ts.Fail("unknown azure oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Azure.URL = server.URL + ts.Config.External.Azure.ApiURL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalAzure_AuthorizationCode() { + setupAzureOverrideVerifiers() + + ts.Config.DisableSignup = false + tokenCount := 0 + code := "authcode" + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUser) + defer server.Close() + + u := performAuthorization(ts, "azure", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, -1, "azure@example.com", "Azure Test", "azuretestid", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalAzureDisableSignupErrorWhenNoUser() { + setupAzureOverrideVerifiers() + + ts.Config.DisableSignup = true + tokenCount := 0 + code := "authcode" + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUser) + defer server.Close() + + u := performAuthorization(ts, "azure", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "azure@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalAzureDisableSignupErrorWhenNoEmail() { + setupAzureOverrideVerifiers() + + ts.Config.DisableSignup = true + tokenCount := 0 + code := "authcode" + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUserNoEmail) + defer server.Close() + + u := performAuthorization(ts, "azure", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "azure@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalAzureDisableSignupSuccessWithPrimaryEmail() { + 
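+	// Signups are disabled for this test, but the flow still succeeds because a matching
+	// user is pre-created and the Azure identity resolves to that existing account.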
setupAzureOverrideVerifiers() + + ts.Config.DisableSignup = true + + ts.createUser("azuretestid", "azure@example.com", "Azure Test", "", "") + + tokenCount := 0 + code := "authcode" + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUser) + defer server.Close() + + u := performAuthorization(ts, "azure", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, -1, "azure@example.com", "Azure Test", "azuretestid", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalAzureSuccessWhenMatchingToken() { + setupAzureOverrideVerifiers() + + // name should be populated from Azure API + ts.createUser("azuretestid", "azure@example.com", "", "", "invite_token") + + tokenCount := 0 + code := "authcode" + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUser) + defer server.Close() + + u := performAuthorization(ts, "azure", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, -1, "azure@example.com", "Azure Test", "azuretestid", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalAzureErrorWhenNoMatchingToken() { + setupAzureOverrideVerifiers() + + tokenCount := 0 + code := "authcode" + azureUser := `{"name":"Azure Test","avatar":{"href":"http://example.com/avatar"}}` + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "azure", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalAzureErrorWhenWrongToken() { + setupAzureOverrideVerifiers() + + ts.createUser("azuretestid", "azure@example.com", "", "", "invite_token") + + tokenCount := 0 + code := "authcode" + azureUser := `{"name":"Azure Test","avatar":{"href":"http://example.com/avatar"}}` + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "azure", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalAzureErrorWhenEmailDoesntMatch() { + setupAzureOverrideVerifiers() + + ts.createUser("azuretestid", "azure@example.com", "", "", "invite_token") + + tokenCount := 0 + code := "authcode" + azureUser := `{"name":"Azure Test", "email":"other@example.com", "avatar":{"href":"http://example.com/avatar"}}` + server := AzureTestSignupSetup(ts, &tokenCount, code, azureUser) + defer server.Close() + + u := performAuthorization(ts, "azure", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} diff --git a/auth_v2.169.0/internal/api/external_bitbucket_test.go b/auth_v2.169.0/internal/api/external_bitbucket_test.go new file mode 100644 index 0000000..66b3bd4 --- /dev/null +++ b/auth_v2.169.0/internal/api/external_bitbucket_test.go @@ -0,0 +1,195 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const ( + bitbucketUser string = `{"uuid":"bitbucketTestId","display_name":"Bitbucket Test","avatar":{"href":"http://example.com/avatar"}}` +) + +func (ts *ExternalTestSuite) TestSignupExternalBitbucket() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=bitbucket", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + 
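+	// The redirect to Bitbucket must carry the configured redirect_uri and client_id,
+	// request an authorization code, and ask for the "account email" scopes.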
ts.Equal(ts.Config.External.Bitbucket.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Bitbucket.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("account email", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("bitbucket", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func BitbucketTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string, emails string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/site/oauth2/access_token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Bitbucket.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"bitbucket_token","expires_in":100000}`) + case "/2.0/user": + *userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, user) + case "/2.0/user/emails": + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, emails) + default: + w.WriteHeader(500) + ts.Fail("unknown bitbucket oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Bitbucket.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalBitbucket_AuthorizationCode() { + ts.Config.DisableSignup = false + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `{"values":[{"email":"bitbucket@example.com","is_primary":true,"is_confirmed":true}]}` + server := BitbucketTestSignupSetup(ts, &tokenCount, &userCount, code, bitbucketUser, emails) + defer server.Close() + + u := performAuthorization(ts, "bitbucket", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "bitbucket@example.com", "Bitbucket Test", "bitbucketTestId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalBitbucketDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + bitbucketUser := `{"display_name":"Bitbucket Test","avatar":{"href":"http://example.com/avatar"}}` + emails := `{"values":[{"email":"bitbucket@example.com","is_primary":true,"is_confirmed":true}]}` + server := BitbucketTestSignupSetup(ts, &tokenCount, &userCount, code, bitbucketUser, emails) + defer server.Close() + + u := performAuthorization(ts, "bitbucket", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "bitbucket@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalBitbucketDisableSignupErrorWhenNoEmail() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `{"values":[{}]}` + server := BitbucketTestSignupSetup(ts, &tokenCount, &userCount, code, bitbucketUser, emails) + defer server.Close() + + u := performAuthorization(ts, "bitbucket", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "bitbucket@example.com") + +} + +func (ts *ExternalTestSuite) TestSignupExternalBitbucketDisableSignupSuccessWithPrimaryEmail() { + 
ts.Config.DisableSignup = true + + ts.createUser("bitbucketTestId", "bitbucket@example.com", "Bitbucket Test", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `{"values":[{"email":"bitbucket@example.com","is_primary":true,"is_confirmed":true}]}` + server := BitbucketTestSignupSetup(ts, &tokenCount, &userCount, code, bitbucketUser, emails) + defer server.Close() + + u := performAuthorization(ts, "bitbucket", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "bitbucket@example.com", "Bitbucket Test", "bitbucketTestId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalBitbucketDisableSignupSuccessWithSecondaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("bitbucketTestId", "secondary@example.com", "Bitbucket Test", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `{"values":[{"email":"primary@example.com","is_primary":true,"is_confirmed":true},{"email":"secondary@example.com","is_primary":false,"is_confirmed":true}]}` + server := BitbucketTestSignupSetup(ts, &tokenCount, &userCount, code, bitbucketUser, emails) + defer server.Close() + + u := performAuthorization(ts, "bitbucket", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "secondary@example.com", "Bitbucket Test", "bitbucketTestId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalBitbucketSuccessWhenMatchingToken() { + // name and avatar should be populated from Bitbucket API + ts.createUser("bitbucketTestId", "bitbucket@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `{"values":[{"email":"bitbucket@example.com","is_primary":true,"is_confirmed":true}]}` + server := BitbucketTestSignupSetup(ts, &tokenCount, &userCount, code, bitbucketUser, emails) + defer server.Close() + + u := performAuthorization(ts, "bitbucket", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "bitbucket@example.com", "Bitbucket Test", "bitbucketTestId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalBitbucketErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + bitbucketUser := `{"display_name":"Bitbucket Test","avatar":{"href":"http://example.com/avatar"}}` + emails := `{"values":[{"email":"bitbucket@example.com","is_primary":true,"is_confirmed":true}]}` + server := BitbucketTestSignupSetup(ts, &tokenCount, &userCount, code, bitbucketUser, emails) + defer server.Close() + + w := performAuthorizationRequest(ts, "bitbucket", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalBitbucketErrorWhenWrongToken() { + ts.createUser("bitbucketTestId", "bitbucket@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + bitbucketUser := `{"display_name":"Bitbucket Test","avatar":{"href":"http://example.com/avatar"}}` + emails := `{"values":[{"email":"bitbucket@example.com","is_primary":true,"is_confirmed":true}]}` + server := BitbucketTestSignupSetup(ts, &tokenCount, &userCount, code, bitbucketUser, emails) + defer server.Close() + + w := performAuthorizationRequest(ts, "bitbucket", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalBitbucketErrorWhenEmailDoesntMatch() { + ts.createUser("bitbucketTestId", 
"bitbucket@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `{"values":[{"email":"other@example.com","is_primary":true,"is_confirmed":true}]}` + server := BitbucketTestSignupSetup(ts, &tokenCount, &userCount, code, bitbucketUser, emails) + defer server.Close() + + u := performAuthorization(ts, "bitbucket", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} diff --git a/auth_v2.169.0/internal/api/external_discord_test.go b/auth_v2.169.0/internal/api/external_discord_test.go new file mode 100644 index 0000000..7b6be8d --- /dev/null +++ b/auth_v2.169.0/internal/api/external_discord_test.go @@ -0,0 +1,167 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const ( + discordUser string = `{"id":"discordTestId","avatar":"abc","email":"discord@example.com","username":"Discord Test","verified":true,"discriminator":"0001"}}` + discordUserWrongEmail string = `{"id":"discordTestId","avatar":"abc","email":"other@example.com","username":"Discord Test","verified":true}}` + discordUserNoEmail string = `{"id":"discordTestId","avatar":"abc","username":"Discord Test","verified":true}}` +) + +func (ts *ExternalTestSuite) TestSignupExternalDiscord() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=discord", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Discord.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Discord.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("email identify", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("discord", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func DiscordTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/oauth2/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Discord.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"discord_token","expires_in":100000}`) + case "/api/users/@me": + *userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, user) + default: + w.WriteHeader(500) + ts.Fail("unknown discord oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Discord.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalDiscord_AuthorizationCode() { + ts.Config.DisableSignup = false + tokenCount, userCount := 0, 0 + code := "authcode" + server := DiscordTestSignupSetup(ts, &tokenCount, &userCount, code, discordUser) + defer server.Close() + + u := performAuthorization(ts, "discord", code, "") + + 
assertAuthorizationSuccess(ts, u, tokenCount, userCount, "discord@example.com", "Discord Test", "discordTestId", "https://cdn.discordapp.com/avatars/discordTestId/abc.png") +} + +func (ts *ExternalTestSuite) TestSignupExternalDiscordDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := DiscordTestSignupSetup(ts, &tokenCount, &userCount, code, discordUser) + defer server.Close() + + u := performAuthorization(ts, "discord", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "discord@example.com") +} +func (ts *ExternalTestSuite) TestSignupExternalDiscordDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := DiscordTestSignupSetup(ts, &tokenCount, &userCount, code, discordUserNoEmail) + defer server.Close() + + u := performAuthorization(ts, "discord", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "discord@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalDiscordDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("discordTestId", "discord@example.com", "Discord Test", "https://cdn.discordapp.com/avatars/discordTestId/abc.png", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := DiscordTestSignupSetup(ts, &tokenCount, &userCount, code, discordUser) + defer server.Close() + + u := performAuthorization(ts, "discord", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "discord@example.com", "Discord Test", "discordTestId", "https://cdn.discordapp.com/avatars/discordTestId/abc.png") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalDiscordSuccessWhenMatchingToken() { + // name and avatar should be populated from Discord API + ts.createUser("discordTestId", "discord@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := DiscordTestSignupSetup(ts, &tokenCount, &userCount, code, discordUser) + defer server.Close() + + u := performAuthorization(ts, "discord", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "discord@example.com", "Discord Test", "discordTestId", "https://cdn.discordapp.com/avatars/discordTestId/abc.png") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalDiscordErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + server := DiscordTestSignupSetup(ts, &tokenCount, &userCount, code, discordUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "discord", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalDiscordErrorWhenWrongToken() { + ts.createUser("discordTestId", "discord@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := DiscordTestSignupSetup(ts, &tokenCount, &userCount, code, discordUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "discord", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalDiscordErrorWhenEmailDoesntMatch() { + ts.createUser("discordTestId", "discord@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := DiscordTestSignupSetup(ts, &tokenCount, &userCount, code, 
discordUserWrongEmail) + defer server.Close() + + u := performAuthorization(ts, "discord", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} diff --git a/auth_v2.169.0/internal/api/external_facebook_test.go b/auth_v2.169.0/internal/api/external_facebook_test.go new file mode 100644 index 0000000..c1864bb --- /dev/null +++ b/auth_v2.169.0/internal/api/external_facebook_test.go @@ -0,0 +1,167 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const ( + facebookUser string = `{"id":"facebookTestId","name":"Facebook Test","first_name":"Facebook","last_name":"Test","email":"facebook@example.com","picture":{"data":{"url":"http://example.com/avatar"}}}}` + facebookUserWrongEmail string = `{"id":"facebookTestId","name":"Facebook Test","first_name":"Facebook","last_name":"Test","email":"other@example.com","picture":{"data":{"url":"http://example.com/avatar"}}}}` + facebookUserNoEmail string = `{"id":"facebookTestId","name":"Facebook Test","first_name":"Facebook","last_name":"Test","picture":{"data":{"url":"http://example.com/avatar"}}}}` +) + +func (ts *ExternalTestSuite) TestSignupExternalFacebook() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=facebook", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Facebook.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Facebook.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("email", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("facebook", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func FacebookTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/access_token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Facebook.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"facebook_token","expires_in":100000}`) + case "/me": + *userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, user) + default: + w.WriteHeader(500) + ts.Fail("unknown facebook oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Facebook.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalFacebook_AuthorizationCode() { + ts.Config.DisableSignup = false + tokenCount, userCount := 0, 0 + code := "authcode" + server := FacebookTestSignupSetup(ts, &tokenCount, &userCount, code, facebookUser) + defer server.Close() + + u := performAuthorization(ts, "facebook", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "facebook@example.com", "Facebook Test", 
"facebookTestId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalFacebookDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := FacebookTestSignupSetup(ts, &tokenCount, &userCount, code, facebookUser) + defer server.Close() + + u := performAuthorization(ts, "facebook", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "facebook@example.com") +} +func (ts *ExternalTestSuite) TestSignupExternalFacebookDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := FacebookTestSignupSetup(ts, &tokenCount, &userCount, code, facebookUserNoEmail) + defer server.Close() + + u := performAuthorization(ts, "facebook", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "facebook@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalFacebookDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("facebookTestId", "facebook@example.com", "Facebook Test", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := FacebookTestSignupSetup(ts, &tokenCount, &userCount, code, facebookUser) + defer server.Close() + + u := performAuthorization(ts, "facebook", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "facebook@example.com", "Facebook Test", "facebookTestId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFacebookSuccessWhenMatchingToken() { + // name and avatar should be populated from Facebook API + ts.createUser("facebookTestId", "facebook@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := FacebookTestSignupSetup(ts, &tokenCount, &userCount, code, facebookUser) + defer server.Close() + + u := performAuthorization(ts, "facebook", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "facebook@example.com", "Facebook Test", "facebookTestId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFacebookErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + server := FacebookTestSignupSetup(ts, &tokenCount, &userCount, code, facebookUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "facebook", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFacebookErrorWhenWrongToken() { + ts.createUser("facebookTestId", "facebook@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := FacebookTestSignupSetup(ts, &tokenCount, &userCount, code, facebookUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "facebook", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFacebookErrorWhenEmailDoesntMatch() { + ts.createUser("facebookTestId", "facebook@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := FacebookTestSignupSetup(ts, &tokenCount, &userCount, code, facebookUserWrongEmail) + defer server.Close() + + u := performAuthorization(ts, "facebook", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from 
external provider", "invalid_request", "") +} diff --git a/auth_v2.169.0/internal/api/external_figma_test.go b/auth_v2.169.0/internal/api/external_figma_test.go new file mode 100644 index 0000000..6e119b9 --- /dev/null +++ b/auth_v2.169.0/internal/api/external_figma_test.go @@ -0,0 +1,264 @@ +package api + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "time" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/models" +) + +func (ts *ExternalTestSuite) TestSignupExternalFigma() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=figma", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Figma.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Figma.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("files:read", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("figma", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func FigmaTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, email string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/oauth/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Figma.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"figma_token","expires_in":100000,"refresh_token":"figma_token"}`) + case "/v1/me": + *userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprintf(w, `{"id":"figma-test-id","email":"%s","handle":"Figma Test","img_url":"http://example.com/avatar"}`, email) + default: + w.WriteHeader(500) + ts.Fail("unknown figma oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Figma.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalFigma_AuthorizationCode() { + tokenCount, userCount := 0, 0 + code := "authcode" + email := "figma@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "figma", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "figma@example.com", "Figma Test", "figma-test-id", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalFigma_PKCE() { + tokenCount, userCount := 0, 0 + code := "authcode" + + // for the plain challenge method, the code verifier == code challenge + // code challenge has to be between 43 - 128 chars for the plain challenge method + codeVerifier := "testtesttesttesttesttesttesttesttesttesttesttesttesttest" + + email := "figma@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + cases := []struct { + desc 
string + codeChallengeMethod string + }{ + { + desc: "SHA256", + codeChallengeMethod: "s256", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var codeChallenge string + if c.codeChallengeMethod == "s256" { + hashedCodeVerifier := sha256.Sum256([]byte(codeVerifier)) + codeChallenge = base64.RawURLEncoding.EncodeToString(hashedCodeVerifier[:]) + } else { + codeChallenge = codeVerifier + } + // Check for valid auth code returned + u := performPKCEAuthorization(ts, "figma", code, codeChallenge, c.codeChallengeMethod) + m, err := url.ParseQuery(u.RawQuery) + authCode := m["code"][0] + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), authCode) + + // Check for valid provider access token, mock does not return refresh token + user, err := models.FindUserByEmailAndAudience(ts.API.db, "figma@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), user) + flowState, err := models.FindFlowStateByAuthCode(ts.API.db, authCode) + require.NoError(ts.T(), err) + require.Equal(ts.T(), "figma_token", flowState.ProviderAccessToken) + + // Exchange Auth Code for token + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "code_verifier": codeVerifier, + "auth_code": authCode, + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=pkce", &buffer) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // Validate that access token and provider tokens are present + data := AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.NotEmpty(ts.T(), data.Token) + require.NotEmpty(ts.T(), data.RefreshToken) + require.NotEmpty(ts.T(), data.ProviderAccessToken) + require.Equal(ts.T(), data.User.ID, user.ID) + }) + } +} + +func (ts *ExternalTestSuite) TestSignupExternalFigmaDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + email := "figma@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "figma", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "figma@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalFigmaDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + email := "" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "figma", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "figma@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalFigmaDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("figma-test-id", "figma@example.com", "Figma Test", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "figma@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "figma", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "figma@example.com", "Figma Test", "figma-test-id", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) 
TestInviteTokenExternalFigmaSuccessWhenMatchingToken() { + // name and avatar should be populated from Figma API + ts.createUser("figma-test-id", "figma@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "figma@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "figma", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "figma@example.com", "Figma Test", "figma-test-id", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFigmaErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + email := "figma@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + w := performAuthorizationRequest(ts, "figma", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFigmaErrorWhenWrongToken() { + ts.createUser("figma-test-id", "figma@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "figma@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + w := performAuthorizationRequest(ts, "figma", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFigmaErrorWhenEmailDoesntMatch() { + ts.createUser("figma-test-id", "figma@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "other@example.com" + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "figma", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalFigmaErrorWhenUserBanned() { + tokenCount, userCount := 0, 0 + code := "authcode" + email := "figma@example.com" + + server := FigmaTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "figma", code, "") + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "figma@example.com", "Figma Test", "figma-test-id", "http://example.com/avatar") + + user, err := models.FindUserByEmailAndAudience(ts.API.db, "figma@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + t := time.Now().Add(24 * time.Hour) + user.BannedUntil = &t + require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) + + u = performAuthorization(ts, "figma", code, "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") +} diff --git a/auth_v2.169.0/internal/api/external_fly_test.go b/auth_v2.169.0/internal/api/external_fly_test.go new file mode 100644 index 0000000..cf357c9 --- /dev/null +++ b/auth_v2.169.0/internal/api/external_fly_test.go @@ -0,0 +1,264 @@ +package api + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "time" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/models" +) + +func (ts *ExternalTestSuite) TestSignupExternalFly() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=fly", nil) + w := httptest.NewRecorder() + 
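+	// /authorize?provider=fly should answer with a 302 whose Location carries Fly's
+	// OAuth parameters and the signed state JWT asserted below.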
ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Fly.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Fly.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("read", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("fly", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func FlyTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, email string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Fly.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"fly_token","expires_in":100000,"refresh_token":"fly_refresh_token"}`) + case "/oauth/token/info": + *userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprintf(w, `{"resource_owner_id":"test_resource_owner_id","scope":["read"],"expires_in":1111,"application":{"uid":"test_app_uid"},"created_at":1696003692,"user_id":"test_user_id","user_name":"test_user","email":"%s","organizations":[{"id":"test_org_id","role":"test"}]}`, email) + default: + w.WriteHeader(500) + ts.Fail("unknown fly oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Fly.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalFly_AuthorizationCode() { + tokenCount, userCount := 0, 0 + code := "authcode" + email := "fly@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "fly", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "fly@example.com", "test_user", "test_user_id", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalFly_PKCE() { + tokenCount, userCount := 0, 0 + code := "authcode" + + // for the plain challenge method, the code verifier == code challenge + // code challenge has to be between 43 - 128 chars for the plain challenge method + codeVerifier := "testtesttesttesttesttesttesttesttesttesttesttesttesttest" + + email := "fly@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + cases := []struct { + desc string + codeChallengeMethod string + }{ + { + desc: "SHA256", + codeChallengeMethod: "s256", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var codeChallenge string + if c.codeChallengeMethod == "s256" { + hashedCodeVerifier := sha256.Sum256([]byte(codeVerifier)) + codeChallenge = base64.RawURLEncoding.EncodeToString(hashedCodeVerifier[:]) + } else { + codeChallenge = codeVerifier + } + // Check for valid auth code returned + u := performPKCEAuthorization(ts, "fly", code, codeChallenge, c.codeChallengeMethod) + m, err := url.ParseQuery(u.RawQuery) + authCode := m["code"][0] + require.NoError(ts.T(), err) + 
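+			// The one-time auth code from the redirect is exchanged next at
+			// /token?grant_type=pkce together with the original code verifier.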
require.NotEmpty(ts.T(), authCode) + + // Check for valid provider access token, mock does not return refresh token + user, err := models.FindUserByEmailAndAudience(ts.API.db, "fly@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), user) + flowState, err := models.FindFlowStateByAuthCode(ts.API.db, authCode) + require.NoError(ts.T(), err) + require.Equal(ts.T(), "fly_token", flowState.ProviderAccessToken) + + // Exchange Auth Code for token + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "code_verifier": codeVerifier, + "auth_code": authCode, + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=pkce", &buffer) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // Validate that access token and provider tokens are present + data := AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.NotEmpty(ts.T(), data.Token) + require.NotEmpty(ts.T(), data.RefreshToken) + require.NotEmpty(ts.T(), data.ProviderAccessToken) + require.Equal(ts.T(), data.User.ID, user.ID) + }) + } +} + +func (ts *ExternalTestSuite) TestSignupExternalFlyDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + email := "fly@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "fly", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", email) +} + +func (ts *ExternalTestSuite) TestSignupExternalFlyDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + email := "" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "fly", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "fly@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalFlyDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("test_user_id", "fly@example.com", "test_user", "", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "fly@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "fly", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "fly@example.com", "test_user", "test_user_id", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFlySuccessWhenMatchingToken() { + // name and avatar should be populated from fly API + ts.createUser("test_user_id", "fly@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "fly@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "fly", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "fly@example.com", "test_user", "test_user_id", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFlyErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + email := "fly@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, 
email) + defer server.Close() + + w := performAuthorizationRequest(ts, "fly", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFlyErrorWhenWrongToken() { + ts.createUser("test_user_id", "fly@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "fly@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + w := performAuthorizationRequest(ts, "fly", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalFlyErrorWhenEmailDoesntMatch() { + ts.createUser("test_user_id", "fly@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + email := "other@example.com" + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "fly", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalFlyErrorWhenUserBanned() { + tokenCount, userCount := 0, 0 + code := "authcode" + email := "fly@example.com" + + server := FlyTestSignupSetup(ts, &tokenCount, &userCount, code, email) + defer server.Close() + + u := performAuthorization(ts, "fly", code, "") + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "fly@example.com", "test_user", "test_user_id", "") + + user, err := models.FindUserByEmailAndAudience(ts.API.db, "fly@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + t := time.Now().Add(24 * time.Hour) + user.BannedUntil = &t + require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) + + u = performAuthorization(ts, "fly", code, "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") +} diff --git a/auth_v2.169.0/internal/api/external_github_test.go b/auth_v2.169.0/internal/api/external_github_test.go new file mode 100644 index 0000000..7b9d31e --- /dev/null +++ b/auth_v2.169.0/internal/api/external_github_test.go @@ -0,0 +1,300 @@ +package api + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "time" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/models" +) + +func (ts *ExternalTestSuite) TestSignupExternalGithub() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=github", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Github.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Github.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("user:email", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("github", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func GitHubTestSignupSetup(ts *ExternalTestSuite, tokenCount 
*int, userCount *int, code string, emails string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/login/oauth/access_token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Github.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"github_token","expires_in":100000}`) + case "/api/v3/user": + *userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"id":123, "name":"GitHub Test","avatar_url":"http://example.com/avatar"}`) + case "/api/v3/user/emails": + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, emails) + default: + w.WriteHeader(500) + ts.Fail("unknown github oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Github.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalGitHub_AuthorizationCode() { + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"github@example.com", "primary": true, "verified": true}]` + server := GitHubTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "github", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "github@example.com", "GitHub Test", "123", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalGitHub_PKCE() { + tokenCount, userCount := 0, 0 + code := "authcode" + + // for the plain challenge method, the code verifier == code challenge + // code challenge has to be between 43 - 128 chars for the plain challenge method + codeVerifier := "testtesttesttesttesttesttesttesttesttesttesttesttesttest" + + emails := `[{"email":"github@example.com", "primary": true, "verified": true}]` + server := GitHubTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + cases := []struct { + desc string + codeChallengeMethod string + }{ + { + desc: "SHA256", + codeChallengeMethod: "s256", + }, + { + desc: "Plain", + codeChallengeMethod: "plain", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var codeChallenge string + if c.codeChallengeMethod == "s256" { + hashedCodeVerifier := sha256.Sum256([]byte(codeVerifier)) + codeChallenge = base64.RawURLEncoding.EncodeToString(hashedCodeVerifier[:]) + } else { + codeChallenge = codeVerifier + } + // Check for valid auth code returned + u := performPKCEAuthorization(ts, "github", code, codeChallenge, c.codeChallengeMethod) + m, err := url.ParseQuery(u.RawQuery) + authCode := m["code"][0] + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), authCode) + + // Check for valid provider access token, mock does not return refresh token + user, err := models.FindUserByEmailAndAudience(ts.API.db, "github@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), user) + flowState, err := models.FindFlowStateByAuthCode(ts.API.db, authCode) + require.NoError(ts.T(), err) + require.Equal(ts.T(), "github_token", flowState.ProviderAccessToken) + + // Exchange Auth Code for token + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "code_verifier": codeVerifier, + "auth_code": authCode, + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=pkce", &buffer) + 
req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // Validate that access token and provider tokens are present + data := AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.NotEmpty(ts.T(), data.Token) + require.NotEmpty(ts.T(), data.RefreshToken) + require.NotEmpty(ts.T(), data.ProviderAccessToken) + require.Equal(ts.T(), data.User.ID, user.ID) + }) + } + +} + +func (ts *ExternalTestSuite) TestSignupExternalGitHubDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"github@example.com", "primary": true, "verified": true}]` + server := GitHubTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "github", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "github@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalGitHubDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"primary": true, "verified": true}]` + server := GitHubTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "github", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "github@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalGitHubDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("123", "github@example.com", "GitHub Test", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"github@example.com", "primary": true, "verified": true}]` + server := GitHubTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "github", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "github@example.com", "GitHub Test", "123", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalGitHubDisableSignupSuccessWithNonPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("123", "secondary@example.com", "GitHub Test", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"primary@example.com", "primary": true, "verified": true},{"email":"secondary@example.com", "primary": false, "verified": true}]` + server := GitHubTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "github", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "secondary@example.com", "GitHub Test", "123", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalGitHubSuccessWhenMatchingToken() { + // name and avatar should be populated from GitHub API + ts.createUser("123", "github@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"github@example.com", "primary": true, "verified": true}]` + server := GitHubTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "github", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, 
"github@example.com", "GitHub Test", "123", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalGitHubErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"github@example.com", "primary": true, "verified": true}]` + server := GitHubTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + w := performAuthorizationRequest(ts, "github", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalGitHubErrorWhenWrongToken() { + ts.createUser("123", "github@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"github@example.com", "primary": true, "verified": true}]` + server := GitHubTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + w := performAuthorizationRequest(ts, "github", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalGitHubErrorWhenEmailDoesntMatch() { + ts.createUser("123", "github@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"other@example.com", "primary": true, "verified": true}]` + server := GitHubTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "github", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenVerifiedFalse() { + ts.Config.Mailer.AllowUnverifiedEmailSignIns = false + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"github@example.com", "primary": true, "verified": false}]` + server := GitHubTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "github", code, "") + + assertAuthorizationFailure(ts, u, "Unverified email with github. 
A confirmation email has been sent to your github email", "access_denied", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenUserBanned() { + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"github@example.com", "primary": true, "verified": true}]` + server := GitHubTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "github", code, "") + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "github@example.com", "GitHub Test", "123", "http://example.com/avatar") + + user, err := models.FindUserByEmailAndAudience(ts.API.db, "github@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + t := time.Now().Add(24 * time.Hour) + user.BannedUntil = &t + require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) + + u = performAuthorization(ts, "github", code, "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") +} diff --git a/auth_v2.169.0/internal/api/external_gitlab_test.go b/auth_v2.169.0/internal/api/external_gitlab_test.go new file mode 100644 index 0000000..5a14a0a --- /dev/null +++ b/auth_v2.169.0/internal/api/external_gitlab_test.go @@ -0,0 +1,199 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const ( + gitlabUser string = `{"id":123,"email":"gitlab@example.com","name":"GitLab Test","avatar_url":"http://example.com/avatar","confirmed_at":"2012-05-23T09:05:22Z"}` + gitlabUserWrongEmail string = `{"id":123,"email":"other@example.com","name":"GitLab Test","avatar_url":"http://example.com/avatar","confirmed_at":"2012-05-23T09:05:22Z"}` + gitlabUserNoEmail string = `{"id":123,"name":"Gitlab Test","avatar_url":"http://example.com/avatar"}` +) + +func (ts *ExternalTestSuite) TestSignupExternalGitlab() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=gitlab", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Gitlab.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Gitlab.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("read_user", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("gitlab", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func GitlabTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string, emails string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Gitlab.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"gitlab_token","expires_in":100000}`) + case "/api/v4/user": + *userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, user) + case 
"/api/v4/user/emails": + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, emails) + default: + w.WriteHeader(500) + ts.Fail("unknown gitlab oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Gitlab.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalGitlab_AuthorizationCode() { + // additional emails from GitLab don't return confirm status + ts.Config.Mailer.Autoconfirm = true + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"id":1,"email":"gitlab@example.com"}]` + server := GitlabTestSignupSetup(ts, &tokenCount, &userCount, code, gitlabUser, emails) + defer server.Close() + + u := performAuthorization(ts, "gitlab", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "gitlab@example.com", "GitLab Test", "123", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalGitLabDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"id":1,"email":"gitlab@example.com"}]` + server := GitlabTestSignupSetup(ts, &tokenCount, &userCount, code, gitlabUser, emails) + defer server.Close() + + u := performAuthorization(ts, "gitlab", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "gitlab@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalGitLabDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[]` + server := GitlabTestSignupSetup(ts, &tokenCount, &userCount, code, gitlabUserNoEmail, emails) + defer server.Close() + + u := performAuthorization(ts, "gitlab", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "gitlab@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalGitLabDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("123", "gitlab@example.com", "GitLab Test", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := "[]" + server := GitlabTestSignupSetup(ts, &tokenCount, &userCount, code, gitlabUser, emails) + defer server.Close() + + u := performAuthorization(ts, "gitlab", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "gitlab@example.com", "GitLab Test", "123", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalGitLabDisableSignupSuccessWithSecondaryEmail() { + // additional emails from GitLab don't return confirm status + ts.Config.Mailer.Autoconfirm = true + ts.Config.DisableSignup = true + + ts.createUser("123", "secondary@example.com", "GitLab Test", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"id":1,"email":"secondary@example.com"}]` + server := GitlabTestSignupSetup(ts, &tokenCount, &userCount, code, gitlabUser, emails) + defer server.Close() + + u := performAuthorization(ts, "gitlab", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "secondary@example.com", "GitLab Test", "123", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalGitLabSuccessWhenMatchingToken() { + // name and avatar should be populated from GitLab API + ts.createUser("123", "gitlab@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := "[]" + server := GitlabTestSignupSetup(ts, 
&tokenCount, &userCount, code, gitlabUser, emails) + defer server.Close() + + u := performAuthorization(ts, "gitlab", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "gitlab@example.com", "GitLab Test", "123", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalGitLabErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + emails := "[]" + server := GitlabTestSignupSetup(ts, &tokenCount, &userCount, code, gitlabUser, emails) + defer server.Close() + + w := performAuthorizationRequest(ts, "gitlab", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalGitLabErrorWhenWrongToken() { + ts.createUser("123", "gitlab@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := "[]" + server := GitlabTestSignupSetup(ts, &tokenCount, &userCount, code, gitlabUser, emails) + defer server.Close() + + w := performAuthorizationRequest(ts, "gitlab", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalGitLabErrorWhenEmailDoesntMatch() { + ts.createUser("123", "gitlab@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := "[]" + server := GitlabTestSignupSetup(ts, &tokenCount, &userCount, code, gitlabUserWrongEmail, emails) + defer server.Close() + + u := performAuthorization(ts, "gitlab", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} diff --git a/auth_v2.169.0/internal/api/external_google_test.go b/auth_v2.169.0/internal/api/external_google_test.go new file mode 100644 index 0000000..7b3b6d1 --- /dev/null +++ b/auth_v2.169.0/internal/api/external_google_test.go @@ -0,0 +1,181 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/api/provider" +) + +const ( + googleUser string = `{"id":"googleTestId","name":"Google Test","picture":"http://example.com/avatar","email":"google@example.com","verified_email":true}}` + googleUserWrongEmail string = `{"id":"googleTestId","name":"Google Test","picture":"http://example.com/avatar","email":"other@example.com","verified_email":true}}` + googleUserNoEmail string = `{"id":"googleTestId","name":"Google Test","picture":"http://example.com/avatar","verified_email":false}}` +) + +func (ts *ExternalTestSuite) TestSignupExternalGoogle() { + provider.ResetGoogleProvider() + + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=google", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Google.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Google.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("email profile", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil 
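+ // the state parameter is an HS256-signed JWT, so the keyfunc above just returns
+ // the configured JWT secret for verification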
+ }) + ts.Require().NoError(err) + + ts.Equal("google", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func GoogleTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string) *httptest.Server { + provider.ResetGoogleProvider() + + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid-configuration": + w.Header().Add("Content-Type", "application/json") + require.NoError(ts.T(), json.NewEncoder(w).Encode(map[string]any{ + "issuer": server.URL, + "token_endpoint": server.URL + "/o/oauth2/token", + })) + case "/o/oauth2/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Google.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"google_token","expires_in":100000}`) + case "/userinfo/v2/me": + *userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, user) + default: + w.WriteHeader(500) + ts.Fail("unknown google oauth call %s", r.URL.Path) + } + })) + + provider.OverrideGoogleProvider(server.URL, server.URL+"/userinfo/v2/me") + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalGoogle_AuthorizationCode() { + ts.Config.DisableSignup = false + tokenCount, userCount := 0, 0 + code := "authcode" + server := GoogleTestSignupSetup(ts, &tokenCount, &userCount, code, googleUser) + defer server.Close() + + u := performAuthorization(ts, "google", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "google@example.com", "Google Test", "googleTestId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalGoogleDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := GoogleTestSignupSetup(ts, &tokenCount, &userCount, code, googleUser) + defer server.Close() + + u := performAuthorization(ts, "google", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "google@example.com") +} +func (ts *ExternalTestSuite) TestSignupExternalGoogleDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := GoogleTestSignupSetup(ts, &tokenCount, &userCount, code, googleUserNoEmail) + defer server.Close() + + u := performAuthorization(ts, "google", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "google@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalGoogleDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("googleTestId", "google@example.com", "Google Test", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := GoogleTestSignupSetup(ts, &tokenCount, &userCount, code, googleUser) + defer server.Close() + + u := performAuthorization(ts, "google", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "google@example.com", "Google Test", "googleTestId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalGoogleSuccessWhenMatchingToken() { + // name and avatar should be populated from Google API + ts.createUser("googleTestId", "google@example.com", "", "", "invite_token") 
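+ // the invite token passed to performAuthorization below is expected to match the
+ // pending invite created above; the wrong-token and no-matching-token variants of
+ // this test expect a 404 instead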
+ + tokenCount, userCount := 0, 0 + code := "authcode" + server := GoogleTestSignupSetup(ts, &tokenCount, &userCount, code, googleUser) + defer server.Close() + + u := performAuthorization(ts, "google", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "google@example.com", "Google Test", "googleTestId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalGoogleErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + server := GoogleTestSignupSetup(ts, &tokenCount, &userCount, code, googleUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "google", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalGoogleErrorWhenWrongToken() { + ts.createUser("googleTestId", "google@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := GoogleTestSignupSetup(ts, &tokenCount, &userCount, code, googleUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "google", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalGoogleErrorWhenEmailDoesntMatch() { + ts.createUser("googleTestId", "google@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := GoogleTestSignupSetup(ts, &tokenCount, &userCount, code, googleUserWrongEmail) + defer server.Close() + + u := performAuthorization(ts, "google", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} diff --git a/auth_v2.169.0/internal/api/external_kakao_test.go b/auth_v2.169.0/internal/api/external_kakao_test.go new file mode 100644 index 0000000..729f723 --- /dev/null +++ b/auth_v2.169.0/internal/api/external_kakao_test.go @@ -0,0 +1,238 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "time" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/models" +) + +func (ts *ExternalTestSuite) TestSignupExternalKakao() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=kakao", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Kakao.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Kakao.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("kakao", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func KakaoTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, emails string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + 
ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Kakao.RedirectURI, r.FormValue("redirect_uri")) + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"kakao_token","expires_in":100000}`) + case "/v2/user/me": + *userCount++ + var emailList []provider.Email + if err := json.Unmarshal([]byte(emails), &emailList); err != nil { + ts.Fail("Invalid email json %s", emails) + } + + var email *provider.Email + + for i, e := range emailList { + if len(e.Email) > 0 { + email = &emailList[i] + break + } + } + + w.Header().Add("Content-Type", "application/json") + if email != nil { + fmt.Fprintf(w, ` + { + "id":123, + "kakao_account": { + "profile": { + "nickname":"Kakao Test", + "profile_image_url":"http://example.com/avatar" + }, + "email": "%v", + "is_email_valid": %v, + "is_email_verified": %v + } + }`, email.Email, email.Verified, email.Verified) + } else { + fmt.Fprint(w, ` + { + "id":123, + "kakao_account": { + "profile": { + "nickname":"Kakao Test", + "profile_image_url":"http://example.com/avatar" + } + } + }`) + } + default: + w.WriteHeader(500) + ts.Fail("unknown kakao oauth call %s", r.URL.Path) + } + })) + ts.Config.External.Kakao.URL = server.URL + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalKakao_AuthorizationCode() { + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + u := performAuthorization(ts, "kakao", code, "") + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "kakao@example.com", "Kakao Test", "123", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalKakaoDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "kakao", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "kakao@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalKakaoDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "kakao", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "kakao@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalKakaoDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("123", "kakao@example.com", "Kakao Test", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "kakao", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "kakao@example.com", "Kakao Test", "123", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKakaoSuccessWhenMatchingToken() { + // name and avatar should be 
populated from Kakao API + ts.createUser("123", "kakao@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "kakao", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "kakao@example.com", "Kakao Test", "123", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKakaoErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + w := performAuthorizationRequest(ts, "kakao", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKakaoErrorWhenWrongToken() { + ts.createUser("123", "kakao@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + w := performAuthorizationRequest(ts, "kakao", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKakaoErrorWhenEmailDoesntMatch() { + ts.createUser("123", "kakao@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"other@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "kakao", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalKakaoErrorWhenVerifiedFalse() { + ts.Config.Mailer.AllowUnverifiedEmailSignIns = false + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": false}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "kakao", code, "") + + assertAuthorizationFailure(ts, u, "Unverified email with kakao. 
A confirmation email has been sent to your kakao email", "access_denied", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalKakaoErrorWhenUserBanned() { + tokenCount, userCount := 0, 0 + code := "authcode" + emails := `[{"email":"kakao@example.com", "primary": true, "verified": true}]` + server := KakaoTestSignupSetup(ts, &tokenCount, &userCount, code, emails) + defer server.Close() + + u := performAuthorization(ts, "kakao", code, "") + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "kakao@example.com", "Kakao Test", "123", "http://example.com/avatar") + + user, err := models.FindUserByEmailAndAudience(ts.API.db, "kakao@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + t := time.Now().Add(24 * time.Hour) + user.BannedUntil = &t + require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) + + u = performAuthorization(ts, "kakao", code, "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") +} diff --git a/auth_v2.169.0/internal/api/external_keycloak_test.go b/auth_v2.169.0/internal/api/external_keycloak_test.go new file mode 100644 index 0000000..a0952ea --- /dev/null +++ b/auth_v2.169.0/internal/api/external_keycloak_test.go @@ -0,0 +1,182 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const ( + keycloakUser string = `{"sub": "keycloaktestid", "name": "Keycloak Test", "email": "keycloak@example.com", "preferred_username": "keycloak", "email_verified": true}` + keycloakUserNoEmail string = `{"sub": "keycloaktestid", "name": "Keycloak Test", "preferred_username": "keycloak", "email_verified": false}` +) + +func (ts *ExternalTestSuite) TestSignupExternalKeycloak() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=keycloak", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Keycloak.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Keycloak.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("profile email", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("keycloak", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func KeycloakTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/protocol/openid-connect/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Keycloak.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"keycloak_token","expires_in":100000}`) + case "/protocol/openid-connect/userinfo": + *userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, user) + default: + w.WriteHeader(500) + ts.Fail("unknown keycloak oauth call %s", r.URL.Path) + } + })) + + 
ts.Config.External.Keycloak.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalKeycloakWithoutURLSetup() { + ts.createUser("keycloaktestid", "keycloak@example.com", "Keycloak Test", "", "") + tokenCount, userCount := 0, 0 + code := "authcode" + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + ts.Config.External.Keycloak.URL = "" + defer server.Close() + + w := performAuthorizationRequest(ts, "keycloak", code) + ts.Equal(w.Code, http.StatusBadRequest) +} + +func (ts *ExternalTestSuite) TestSignupExternalKeycloak_AuthorizationCode() { + ts.Config.DisableSignup = false + ts.createUser("keycloaktestid", "keycloak@example.com", "Keycloak Test", "", "") + tokenCount, userCount := 0, 0 + code := "authcode" + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + defer server.Close() + + u := performAuthorization(ts, "keycloak", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "keycloak@example.com", "Keycloak Test", "keycloaktestid", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalKeycloakDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + defer server.Close() + + u := performAuthorization(ts, "keycloak", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "keycloak@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalKeycloakDisableSignupErrorWhenNoEmail() { + ts.Config.DisableSignup = true + tokenCount, userCount := 0, 0 + code := "authcode" + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUserNoEmail) + defer server.Close() + + u := performAuthorization(ts, "keycloak", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "keycloak@example.com") + +} + +func (ts *ExternalTestSuite) TestSignupExternalKeycloakDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("keycloaktestid", "keycloak@example.com", "Keycloak Test", "", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + defer server.Close() + + u := performAuthorization(ts, "keycloak", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "keycloak@example.com", "Keycloak Test", "keycloaktestid", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKeycloakSuccessWhenMatchingToken() { + // name and avatar should be populated from Keycloak API + ts.createUser("keycloaktestid", "keycloak@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + defer server.Close() + + u := performAuthorization(ts, "keycloak", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "keycloak@example.com", "Keycloak Test", "keycloaktestid", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKeycloakErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + keycloakUser := `{"name":"Keycloak Test","avatar":{"href":"http://example.com/avatar"}}` + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + defer server.Close() + + w := 
performAuthorizationRequest(ts, "keycloak", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKeycloakErrorWhenWrongToken() { + ts.createUser("keycloaktestid", "keycloak@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + keycloakUser := `{"name":"Keycloak Test","avatar":{"href":"http://example.com/avatar"}}` + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "keycloak", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalKeycloakErrorWhenEmailDoesntMatch() { + ts.createUser("keycloaktestid", "keycloak@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + keycloakUser := `{"name":"Keycloak Test", "email":"other@example.com", "avatar":{"href":"http://example.com/avatar"}}` + server := KeycloakTestSignupSetup(ts, &tokenCount, &userCount, code, keycloakUser) + defer server.Close() + + u := performAuthorization(ts, "keycloak", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} diff --git a/auth_v2.169.0/internal/api/external_linkedin_test.go b/auth_v2.169.0/internal/api/external_linkedin_test.go new file mode 100644 index 0000000..fe49932 --- /dev/null +++ b/auth_v2.169.0/internal/api/external_linkedin_test.go @@ -0,0 +1,170 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const ( + linkedinUser string = `{"id":"linkedinTestId","firstName":{"localized":{"en_US":"Linkedin"},"preferredLocale":{"country":"US","language":"en"}},"lastName":{"localized":{"en_US":"Test"},"preferredLocale":{"country":"US","language":"en"}},"profilePicture":{"displayImage~":{"elements":[{"identifiers":[{"identifier":"http://example.com/avatar"}]}]}}}` + linkedinUserNoProfilePic string = `{"id":"linkedinTestId","firstName":{"localized":{"en_US":"Linkedin"},"preferredLocale":{"country":"US","language":"en"}},"lastName":{"localized":{"en_US":"Test"},"preferredLocale":{"country":"US","language":"en"}},"profilePicture":{"displayImage~":{"elements":[]}}}` + linkedinEmail string = `{"elements": [{"handle": "","handle~": {"emailAddress": "linkedin@example.com"}}]}` + linkedinWrongEmail string = `{"elements": [{"handle": "","handle~": {"emailAddress": "other@example.com"}}]}` +) + +func (ts *ExternalTestSuite) TestSignupExternalLinkedin() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=linkedin", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Linkedin.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Linkedin.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("r_emailaddress r_liteprofile", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("linkedin", 
claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func LinkedinTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string, email string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/v2/accessToken": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Linkedin.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"linkedin_token","expires_in":100000}`) + case "/v2/me": + *userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, user) + case "/v2/emailAddress": + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, email) + default: + w.WriteHeader(500) + ts.Fail("unknown linkedin oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Linkedin.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalLinkedin_AuthorizationCode() { + ts.Config.DisableSignup = false + tokenCount, userCount := 0, 0 + code := "authcode" + server := LinkedinTestSignupSetup(ts, &tokenCount, &userCount, code, linkedinUser, linkedinEmail) + defer server.Close() + + u := performAuthorization(ts, "linkedin", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "linkedin@example.com", "Linkedin Test", "linkedinTestId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalLinkedinDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := LinkedinTestSignupSetup(ts, &tokenCount, &userCount, code, linkedinUser, linkedinEmail) + defer server.Close() + + u := performAuthorization(ts, "linkedin", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "linkedin@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalLinkedinDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("linkedinTestId", "linkedin@example.com", "Linkedin Test", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := LinkedinTestSignupSetup(ts, &tokenCount, &userCount, code, linkedinUser, linkedinEmail) + defer server.Close() + + u := performAuthorization(ts, "linkedin", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "linkedin@example.com", "Linkedin Test", "linkedinTestId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalLinkedinSuccessWhenMatchingToken() { + // name and avatar should be populated from Linkedin API + ts.createUser("linkedinTestId", "linkedin@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := LinkedinTestSignupSetup(ts, &tokenCount, &userCount, code, linkedinUser, linkedinEmail) + defer server.Close() + + u := performAuthorization(ts, "linkedin", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "linkedin@example.com", "Linkedin Test", "linkedinTestId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalLinkedinErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + server := LinkedinTestSignupSetup(ts, &tokenCount, &userCount, code, linkedinUser, linkedinEmail) + defer 
server.Close() + + w := performAuthorizationRequest(ts, "linkedin", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalLinkedinErrorWhenWrongToken() { + ts.createUser("linkedinTestId", "linkedin@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := LinkedinTestSignupSetup(ts, &tokenCount, &userCount, code, linkedinUser, linkedinEmail) + defer server.Close() + + w := performAuthorizationRequest(ts, "linkedin", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalLinkedinErrorWhenEmailDoesntMatch() { + ts.createUser("linkedinTestId", "linkedin@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := LinkedinTestSignupSetup(ts, &tokenCount, &userCount, code, linkedinUser, linkedinWrongEmail) + defer server.Close() + + u := performAuthorization(ts, "linkedin", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalLinkedin_MissingProfilePic() { + tokenCount, userCount := 0, 0 + code := "authcode" + server := LinkedinTestSignupSetup(ts, &tokenCount, &userCount, code, linkedinUserNoProfilePic, linkedinEmail) + defer server.Close() + + u := performAuthorization(ts, "linkedin", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "linkedin@example.com", "Linkedin Test", "linkedinTestId", "") +} diff --git a/auth_v2.169.0/internal/api/external_notion_test.go b/auth_v2.169.0/internal/api/external_notion_test.go new file mode 100644 index 0000000..268e449 --- /dev/null +++ b/auth_v2.169.0/internal/api/external_notion_test.go @@ -0,0 +1,170 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const ( + notionUser string = `{"bot":{"owner":{"user":{"id":"notionTestId","name":"Notion Test","avatar_url":"http://example.com/avatar","person":{"email":"notion@example.com"},"verified_email":true}}}}` + notionUserWrongEmail string = `{"bot":{"owner":{"user":{"id":"notionTestId","name":"Notion Test","avatar_url":"http://example.com/avatar","person":{"email":"other@example.com"},"verified_email":true}}}}` + notionUserNoEmail string = `{"bot":{"owner":{"user":{"id":"notionTestId","name":"Notion Test","avatar_url":"http://example.com/avatar","verified_email":true}}}}` +) + +func (ts *ExternalTestSuite) TestSignupExternalNotion() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=notion", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Notion.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Notion.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("notion", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func 
NotionTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v1/oauth/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Notion.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"notion_token","expires_in":100000}`) + case "/v1/users/me": + *userCount++ + ts.Contains(r.Header, "Authorization") + ts.Contains(r.Header, "Notion-Version") + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, user) + default: + w.WriteHeader(500) + ts.Fail("unknown notion oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Notion.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalNotion_AuthorizationCode() { + ts.Config.DisableSignup = false + tokenCount, userCount := 0, 0 + code := "authcode" + server := NotionTestSignupSetup(ts, &tokenCount, &userCount, code, notionUser) + defer server.Close() + + u := performAuthorization(ts, "notion", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "notion@example.com", "Notion Test", "notionTestId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalNotionDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := NotionTestSignupSetup(ts, &tokenCount, &userCount, code, notionUser) + defer server.Close() + + u := performAuthorization(ts, "notion", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "notion@example.com") +} +func (ts *ExternalTestSuite) TestSignupExternalNotionDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := NotionTestSignupSetup(ts, &tokenCount, &userCount, code, notionUserNoEmail) + defer server.Close() + + u := performAuthorization(ts, "notion", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "notion@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalNotionDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("notionTestId", "notion@example.com", "Notion Test", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := NotionTestSignupSetup(ts, &tokenCount, &userCount, code, notionUser) + defer server.Close() + + u := performAuthorization(ts, "notion", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "notion@example.com", "Notion Test", "notionTestId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalNotionSuccessWhenMatchingToken() { + // name and avatar should be populated from Notion API + ts.createUser("notionTestId", "notion@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := NotionTestSignupSetup(ts, &tokenCount, &userCount, code, notionUser) + defer server.Close() + + u := performAuthorization(ts, "notion", code, "invite_token") + + fmt.Printf("%+v\n", u) + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "notion@example.com", "Notion Test", "notionTestId", 
"http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalNotionErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + server := NotionTestSignupSetup(ts, &tokenCount, &userCount, code, notionUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "notion", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalNotionErrorWhenWrongToken() { + ts.createUser("notionTestId", "notion@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := NotionTestSignupSetup(ts, &tokenCount, &userCount, code, notionUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "notion", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalNotionErrorWhenEmailDoesntMatch() { + ts.createUser("notionTestId", "notion@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := NotionTestSignupSetup(ts, &tokenCount, &userCount, code, notionUserWrongEmail) + defer server.Close() + + u := performAuthorization(ts, "notion", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} diff --git a/auth_v2.169.0/internal/api/external_oauth.go b/auth_v2.169.0/internal/api/external_oauth.go new file mode 100644 index 0000000..cb098e3 --- /dev/null +++ b/auth_v2.169.0/internal/api/external_oauth.go @@ -0,0 +1,155 @@ +package api + +import ( + "context" + "fmt" + "net/http" + "net/url" + + "github.com/mrjones/oauth" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/utilities" +) + +// OAuthProviderData contains the userData and token returned by the oauth provider +type OAuthProviderData struct { + userData *provider.UserProvidedData + token string + refreshToken string + code string +} + +// loadFlowState parses the `state` query parameter as a JWS payload, +// extracting the provider requested +func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx := r.Context() + oauthToken := r.URL.Query().Get("oauth_token") + if oauthToken != "" { + ctx = withRequestToken(ctx, oauthToken) + } + oauthVerifier := r.URL.Query().Get("oauth_verifier") + if oauthVerifier != "" { + ctx = withOAuthVerifier(ctx, oauthVerifier) + } + + var err error + ctx, err = a.loadExternalState(ctx, r) + if err != nil { + u, uerr := url.ParseRequestURI(a.config.SiteURL) + if uerr != nil { + return ctx, internalServerError("site url is improperly formatted").WithInternalError(uerr) + } + + q := getErrorQueryString(err, utilities.GetRequestID(ctx), observability.GetLogEntry(r).Entry, u.Query()) + u.RawQuery = q.Encode() + + http.Redirect(w, r, u.String(), http.StatusSeeOther) + } + return ctx, err +} + +func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType string) (*OAuthProviderData, error) { + var rq url.Values + if err := r.ParseForm(); r.Method == http.MethodPost && err == nil { + rq = r.Form + } else { + rq = r.URL.Query() + } + + extError := rq.Get("error") + if extError != "" { + return nil, oauthError(extError, rq.Get("error_description")) + } + + oauthCode := rq.Get("code") + if oauthCode == "" { + return nil, badRequestError(ErrorCodeBadOAuthCallback, "OAuth callback 
with missing authorization code") + } + + oAuthProvider, err := a.OAuthProvider(ctx, providerType) + if err != nil { + return nil, badRequestError(ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) + } + + log := observability.GetLogEntry(r).Entry + log.WithFields(logrus.Fields{ + "provider": providerType, + "code": oauthCode, + }).Debug("Exchanging oauth code") + + token, err := oAuthProvider.GetOAuthToken(oauthCode) + if err != nil { + return nil, internalServerError("Unable to exchange external code: %s", oauthCode).WithInternalError(err) + } + + userData, err := oAuthProvider.GetUserData(ctx, token) + if err != nil { + return nil, internalServerError("Error getting user profile from external provider").WithInternalError(err) + } + + switch externalProvider := oAuthProvider.(type) { + case *provider.AppleProvider: + // apple only returns user info the first time + oauthUser := rq.Get("user") + if oauthUser != "" { + err := externalProvider.ParseUser(oauthUser, userData) + if err != nil { + return nil, err + } + } + } + + return &OAuthProviderData{ + userData: userData, + token: token.AccessToken, + refreshToken: token.RefreshToken, + code: oauthCode, + }, nil +} + +func (a *API) oAuth1Callback(ctx context.Context, providerType string) (*OAuthProviderData, error) { + oAuthProvider, err := a.OAuthProvider(ctx, providerType) + if err != nil { + return nil, badRequestError(ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) + } + oauthToken := getRequestToken(ctx) + oauthVerifier := getOAuthVerifier(ctx) + var accessToken *oauth.AccessToken + var userData *provider.UserProvidedData + if twitterProvider, ok := oAuthProvider.(*provider.TwitterProvider); ok { + accessToken, err = twitterProvider.Consumer.AuthorizeToken(&oauth.RequestToken{ + Token: oauthToken, + }, oauthVerifier) + if err != nil { + return nil, internalServerError("Unable to retrieve access token").WithInternalError(err) + } + userData, err = twitterProvider.FetchUserData(ctx, accessToken) + if err != nil { + return nil, internalServerError("Error getting user email from external provider").WithInternalError(err) + } + } + + return &OAuthProviderData{ + userData: userData, + token: accessToken.Token, + refreshToken: "", + }, nil + +} + +// OAuthProvider returns the corresponding oauth provider as an OAuthProvider interface +func (a *API) OAuthProvider(ctx context.Context, name string) (provider.OAuthProvider, error) { + providerCandidate, err := a.Provider(ctx, name, "") + if err != nil { + return nil, err + } + + switch p := providerCandidate.(type) { + case provider.OAuthProvider: + return p, nil + default: + return nil, fmt.Errorf("Provider %v cannot be used for OAuth", name) + } +} diff --git a/auth_v2.169.0/internal/api/external_slack_oidc_test.go b/auth_v2.169.0/internal/api/external_slack_oidc_test.go new file mode 100644 index 0000000..acd2e78 --- /dev/null +++ b/auth_v2.169.0/internal/api/external_slack_oidc_test.go @@ -0,0 +1,33 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +func (ts *ExternalTestSuite) TestSignupExternalSlackOIDC() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=slack_oidc", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := 
u.Query() + ts.Equal(ts.Config.External.Slack.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Slack.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("profile email openid", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("slack_oidc", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} diff --git a/auth_v2.169.0/internal/api/external_test.go b/auth_v2.169.0/internal/api/external_test.go new file mode 100644 index 0000000..bef89d7 --- /dev/null +++ b/auth_v2.169.0/internal/api/external_test.go @@ -0,0 +1,254 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type ExternalTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestExternal(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &ExternalTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *ExternalTestSuite) SetupTest() { + ts.Config.DisableSignup = false + ts.Config.Mailer.Autoconfirm = false + + models.TruncateAll(ts.API.db) +} + +func (ts *ExternalTestSuite) createUser(providerId string, email string, name string, avatar string, confirmationToken string) (*models.User, error) { + // Cleanup existing user, if they already exist + if u, _ := models.FindUserByEmailAndAudience(ts.API.db, email, ts.Config.JWT.Aud); u != nil { + require.NoError(ts.T(), ts.API.db.Destroy(u), "Error deleting user") + } + + userData := map[string]interface{}{"provider_id": providerId, "full_name": name} + if avatar != "" { + userData["avatar_url"] = avatar + } + u, err := models.NewUser("", email, "test", ts.Config.JWT.Aud, userData) + + if confirmationToken != "" { + u.ConfirmationToken = confirmationToken + } + ts.Require().NoError(err, "Error making new user") + ts.Require().NoError(ts.API.db.Create(u), "Error creating user") + + if confirmationToken != "" { + ts.Require().NoError(models.CreateOneTimeToken(ts.API.db, u.ID, email, u.ConfirmationToken, models.ConfirmationToken), "Error creating one-time confirmation/invite token") + } + + i, err := models.NewIdentity(u, "email", map[string]interface{}{ + "sub": u.ID.String(), + "email": email, + }) + ts.Require().NoError(err) + ts.Require().NoError(ts.API.db.Create(i), "Error creating identity") + + return u, err +} + +func performAuthorizationRequest(ts *ExternalTestSuite, provider string, inviteToken string) *httptest.ResponseRecorder { + authorizeURL := "http://localhost/authorize?provider=" + provider + if inviteToken != "" { + authorizeURL = authorizeURL + "&invite_token=" + inviteToken + } + + req := httptest.NewRequest(http.MethodGet, authorizeURL, nil) + req.Header.Set("Referer", "https://example.netlify.com/admin") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + return w +} + +func performPKCEAuthorizationRequest(ts *ExternalTestSuite, provider, codeChallenge, codeChallengeMethod string) *httptest.ResponseRecorder { + authorizeURL := "http://localhost/authorize?provider=" + provider + if 
codeChallenge != "" { + authorizeURL = authorizeURL + "&code_challenge=" + codeChallenge + "&code_challenge_method=" + codeChallengeMethod + } + + req := httptest.NewRequest(http.MethodGet, authorizeURL, nil) + req.Header.Set("Referer", "https://example.supabase.com/admin") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + return w +} + +func performPKCEAuthorization(ts *ExternalTestSuite, provider, code, codeChallenge, codeChallengeMethod string) *url.URL { + w := performPKCEAuthorizationRequest(ts, provider, codeChallenge, codeChallengeMethod) + ts.Require().Equal(http.StatusFound, w.Code) + // Get code and state from the redirect + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + state := q.Get("state") + testURL, err := url.Parse("http://localhost/callback") + ts.Require().NoError(err) + v := testURL.Query() + v.Set("code", code) + v.Set("state", state) + testURL.RawQuery = v.Encode() + // Use the code to get a token + req := httptest.NewRequest(http.MethodGet, testURL.String(), nil) + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err = url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + + return u + +} + +func performAuthorization(ts *ExternalTestSuite, provider string, code string, inviteToken string) *url.URL { + w := performAuthorizationRequest(ts, provider, inviteToken) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + state := q.Get("state") + + // auth server callback + testURL, err := url.Parse("http://localhost/callback") + ts.Require().NoError(err) + v := testURL.Query() + v.Set("code", code) + v.Set("state", state) + testURL.RawQuery = v.Encode() + req := httptest.NewRequest(http.MethodGet, testURL.String(), nil) + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err = url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + ts.Require().Equal("/admin", u.Path) + + return u +} + +func assertAuthorizationSuccess(ts *ExternalTestSuite, u *url.URL, tokenCount int, userCount int, email string, name string, providerId string, avatar string) { + // ensure redirect has #access_token=... 
+ v, err := url.ParseQuery(u.RawQuery) + ts.Require().NoError(err) + ts.Require().Empty(v.Get("error_description")) + ts.Require().Empty(v.Get("error")) + + v, err = url.ParseQuery(u.Fragment) + ts.Require().NoError(err) + ts.NotEmpty(v.Get("access_token")) + ts.NotEmpty(v.Get("refresh_token")) + ts.NotEmpty(v.Get("expires_in")) + ts.Equal("bearer", v.Get("token_type")) + + ts.Equal(1, tokenCount) + if userCount > -1 { + ts.Equal(1, userCount) + } + + // ensure user has been created with metadata + user, err := models.FindUserByEmailAndAudience(ts.API.db, email, ts.Config.JWT.Aud) + ts.Require().NoError(err) + ts.Equal(providerId, user.UserMetaData["provider_id"]) + ts.Equal(name, user.UserMetaData["full_name"]) + if avatar == "" { + ts.Equal(nil, user.UserMetaData["avatar_url"]) + } else { + ts.Equal(avatar, user.UserMetaData["avatar_url"]) + } +} + +func assertAuthorizationFailure(ts *ExternalTestSuite, u *url.URL, errorDescription string, errorType string, email string) { + // ensure new sign ups error + v, err := url.ParseQuery(u.RawQuery) + ts.Require().NoError(err) + ts.Require().Equal(errorDescription, v.Get("error_description")) + ts.Require().Equal(errorType, v.Get("error")) + + v, err = url.ParseQuery(u.Fragment) + ts.Require().NoError(err) + ts.Empty(v.Get("access_token")) + ts.Empty(v.Get("refresh_token")) + ts.Empty(v.Get("expires_in")) + ts.Empty(v.Get("token_type")) + + // ensure user is nil + user, err := models.FindUserByEmailAndAudience(ts.API.db, email, ts.Config.JWT.Aud) + ts.Require().Error(err, "User not found") + ts.Require().Nil(user) +} + +// TestSignupExternalUnsupported tests API /authorize for an unsupported external provider +func (ts *ExternalTestSuite) TestSignupExternalUnsupported() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=external", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Equal(w.Code, http.StatusBadRequest) +} + +func (ts *ExternalTestSuite) TestRedirectErrorsShouldPreserveParams() { + // Request with invalid external provider + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=external", nil) + w := httptest.NewRecorder() + cases := []struct { + Desc string + RedirectURL string + QueryParams []string + ErrorMessage string + }{ + { + Desc: "Should preserve redirect query params on error", + RedirectURL: "http://example.com/path?paramforpreservation=value2", + QueryParams: []string{"paramforpreservation"}, + ErrorMessage: "invalid_request", + }, + { + Desc: "Error param should be overwritten", + RedirectURL: "http://example.com/path?error=abc", + QueryParams: []string{"error"}, + ErrorMessage: "invalid_request", + }, + } + for _, c := range cases { + parsedURL, err := url.Parse(c.RedirectURL) + require.Equal(ts.T(), err, nil) + + redirectErrors(ts.API.internalExternalProviderCallback, w, req, parsedURL) + + parsedParams, err := url.ParseQuery(parsedURL.RawQuery) + require.Equal(ts.T(), err, nil) + + // An error and description should be returned + expectedQueryParams := append(c.QueryParams, "error", "error_description") + + for _, expectedQueryParam := range expectedQueryParams { + val, exists := parsedParams[expectedQueryParam] + require.True(ts.T(), exists) + if expectedQueryParam == "error" { + require.Equal(ts.T(), val[0], c.ErrorMessage) + } + } + } +} diff --git a/auth_v2.169.0/internal/api/external_twitch_test.go b/auth_v2.169.0/internal/api/external_twitch_test.go new file mode 100644 index 0000000..694a5ff --- /dev/null +++ 
b/auth_v2.169.0/internal/api/external_twitch_test.go @@ -0,0 +1,171 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const ( + twitchUser string = `{"data":[{"id":"twitchTestId","login":"Twitch user","display_name":"Twitch user","type":"","broadcaster_type":"","description":"","profile_image_url":"https://s.gravatar.com/avatar/23463b99b62a72f26ed677cc556c44e8","offline_image_url":"","email":"twitch@example.com"}]}` + twitchUserWrongEmail string = `{"data":[{"id":"twitchTestId","login":"Twitch user","display_name":"Twitch user","type":"","broadcaster_type":"","description":"","profile_image_url":"https://s.gravatar.com/avatar/23463b99b62a72f26ed677cc556c44e8","offline_image_url":"","email":"other@example.com"}]}` +) + +func (ts *ExternalTestSuite) TestSignupExternalTwitch() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=twitch", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Twitch.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Twitch.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("user:read:email", q.Get("scope")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("twitch", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func TwitchTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth2/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Twitch.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"Twitch_token","expires_in":100000}`) + case "/helix/users": + *userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, user) + default: + w.WriteHeader(500) + ts.Fail("unknown Twitch oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Twitch.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalTwitch_AuthorizationCode() { + ts.Config.DisableSignup = false + tokenCount, userCount := 0, 0 + code := "authcode" + server := TwitchTestSignupSetup(ts, &tokenCount, &userCount, code, twitchUser) + defer server.Close() + + u := performAuthorization(ts, "twitch", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "twitch@example.com", "Twitch user", "twitchTestId", "https://s.gravatar.com/avatar/23463b99b62a72f26ed677cc556c44e8") +} + +func (ts *ExternalTestSuite) TestSignupExternalTwitchDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + TwitchUser := `{"data":[{"id":"1","login":"Twitch user","display_name":"Twitch 
user","type":"","broadcaster_type":"","description":"","profile_image_url":"https://s.gravatar.com/avatar/23463b99b62a72f26ed677cc556c44e8","offline_image_url":"","email":"twitch@example.com"}]}` + server := TwitchTestSignupSetup(ts, &tokenCount, &userCount, code, TwitchUser) + defer server.Close() + + u := performAuthorization(ts, "twitch", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "twitch@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalTwitchDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + TwitchUser := `{"data":[{"id":"1","login":"Twitch user","display_name":"Twitch user","type":"","broadcaster_type":"","description":"","profile_image_url":"https://s.gravatar.com/avatar/23463b99b62a72f26ed677cc556c44e8","offline_image_url":""}]}` + server := TwitchTestSignupSetup(ts, &tokenCount, &userCount, code, TwitchUser) + defer server.Close() + + u := performAuthorization(ts, "twitch", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "twitch@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalTwitchDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("twitchTestId", "twitch@example.com", "Twitch user", "https://s.gravatar.com/avatar/23463b99b62a72f26ed677cc556c44e8", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := TwitchTestSignupSetup(ts, &tokenCount, &userCount, code, twitchUser) + defer server.Close() + + u := performAuthorization(ts, "twitch", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "twitch@example.com", "Twitch user", "twitchTestId", "https://s.gravatar.com/avatar/23463b99b62a72f26ed677cc556c44e8") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalTwitchSuccessWhenMatchingToken() { + // name and avatar should be populated from Twitch API + ts.createUser("twitchTestId", "twitch@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + TwitchUser := `{"data":[{"id":"twitchTestId","login":"Twitch Test","display_name":"Twitch Test","type":"","broadcaster_type":"","description":"","profile_image_url":"https://s.gravatar.com/avatar/23463b99b62a72f26ed677cc556c44e8","offline_image_url":"","email":"twitch@example.com"}]}` + server := TwitchTestSignupSetup(ts, &tokenCount, &userCount, code, TwitchUser) + defer server.Close() + + u := performAuthorization(ts, "twitch", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "twitch@example.com", "Twitch Test", "twitchTestId", "https://s.gravatar.com/avatar/23463b99b62a72f26ed677cc556c44e8") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalTwitchErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + TwitchUser := `{"data":[{"id":"1","login":"Twitch user","display_name":"Twitch user","type":"","broadcaster_type":"","description":"","profile_image_url":"https://s.gravatar.com/avatar/23463b99b62a72f26ed677cc556c44e8","offline_image_url":"","email":"twitch@example.com"}]}` + server := TwitchTestSignupSetup(ts, &tokenCount, &userCount, code, TwitchUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "twitch", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalTwitchErrorWhenWrongToken() { + ts.createUser("twitchTestId", 
"twitch@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := TwitchTestSignupSetup(ts, &tokenCount, &userCount, code, twitchUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "twitch", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalTwitchErrorWhenEmailDoesntMatch() { + ts.createUser("twitchTestId", "twitch@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := TwitchTestSignupSetup(ts, &tokenCount, &userCount, code, twitchUserWrongEmail) + defer server.Close() + + u := performAuthorization(ts, "twitch", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} diff --git a/auth_v2.169.0/internal/api/external_twitter_test.go b/auth_v2.169.0/internal/api/external_twitter_test.go new file mode 100644 index 0000000..d90d5d3 --- /dev/null +++ b/auth_v2.169.0/internal/api/external_twitter_test.go @@ -0,0 +1,42 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" +) + +func (ts *ExternalTestSuite) TestSignupExternalTwitter() { + server := TwitterTestSignupSetup(ts, nil, nil, "", "") + defer server.Close() + + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=twitter", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + + // Twitter uses OAuth1.0 protocol which only returns an oauth_token on the redirect + q := u.Query() + ts.Equal("twitter_oauth_token", q.Get("oauth_token")) +} + +func TwitterTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/request_token": + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, "oauth_token=twitter_oauth_token&oauth_token_secret=twitter_oauth_token_secret&oauth_callback_confirmed=true") + default: + w.WriteHeader(500) + ts.Fail("unknown google oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Twitter.URL = server.URL + + return server +} diff --git a/auth_v2.169.0/internal/api/external_workos_test.go b/auth_v2.169.0/internal/api/external_workos_test.go new file mode 100644 index 0000000..eedd5b0 --- /dev/null +++ b/auth_v2.169.0/internal/api/external_workos_test.go @@ -0,0 +1,221 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const ( + workosUser string = `{"id":"test_prof_workos","first_name":"John","last_name":"Doe","email":"workos@example.com","connection_id":"test_conn_1","organization_id":"test_org_1","connection_type":"test","idp_id":"test_idp_1","object": "profile","raw_attributes": {}}` + workosUserWrongEmail string = `{"id":"test_prof_workos","first_name":"John","last_name":"Doe","email":"other@example.com","connection_id":"test_conn_1","organization_id":"test_org_1","connection_type":"test","idp_id":"test_idp_1","object": "profile","raw_attributes": {}}` + workosUserNoEmail string = 
`{"id":"test_prof_workos","first_name":"John","last_name":"Doe","connection_id":"test_conn_1","organization_id":"test_org_1","connection_type":"test","idp_id":"test_idp_1","object": "profile","raw_attributes": {}}` +) + +func (ts *ExternalTestSuite) TestSignupExternalWorkOSWithConnection() { + connection := "test_connection_id" + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://localhost/authorize?provider=workos&connection=%s", connection), nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.WorkOS.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.WorkOS.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("", q.Get("scope")) + ts.Equal(connection, q.Get("connection")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("workos", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func (ts *ExternalTestSuite) TestSignupExternalWorkOSWithOrganization() { + organization := "test_organization_id" + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://localhost/authorize?provider=workos&organization=%s", organization), nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.WorkOS.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.WorkOS.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("", q.Get("scope")) + ts.Equal(organization, q.Get("organization")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("workos", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func (ts *ExternalTestSuite) TestSignupExternalWorkOSWithProvider() { + provider := "test_provider" + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://localhost/authorize?provider=workos&workos_provider=%s", provider), nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.WorkOS.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.WorkOS.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + ts.Equal("", q.Get("scope")) + ts.Equal(provider, q.Get("provider")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + 
ts.Equal("workos", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func WorkosTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/sso/token": + // WorkOS returns the user data along with the token. + *tokenCount++ + *userCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.WorkOS.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprintf(w, `{"access_token":"workos_token","expires_in":100000,"profile":%s}`, user) + default: + fmt.Printf("%s", r.URL.Path) + w.WriteHeader(500) + ts.Fail("unknown workos oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.WorkOS.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalWorkosAuthorizationCode() { + ts.Config.DisableSignup = false + + tokenCount, userCount := 0, 0 + code := "authcode" + server := WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUser) + defer server.Close() + + u := performAuthorization(ts, "workos", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "workos@example.com", "John Doe", "test_prof_workos", "") +} + +func (ts *ExternalTestSuite) TestSignupExternalWorkosDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUser) + defer server.Close() + + u := performAuthorization(ts, "workos", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "workos@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalWorkosDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUserNoEmail) + defer server.Close() + + u := performAuthorization(ts, "workos", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "workos@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalWorkosDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("test_prof_workos", "workos@example.com", "John Doe", "", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUser) + defer server.Close() + + u := performAuthorization(ts, "workos", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "workos@example.com", "John Doe", "test_prof_workos", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalWorkosSuccessWhenMatchingToken() { + ts.createUser("test_prof_workos", "workos@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUser) + defer server.Close() + + u := performAuthorization(ts, "workos", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "workos@example.com", "John Doe", "test_prof_workos", "") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalWorkosErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + server 
:= WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "workos", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalWorkosErrorWhenWrongToken() { + ts.createUser("test_prof_workos", "workos@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "workos", "wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalWorkosErrorWhenEmailDoesntMatch() { + ts.createUser("test_prof_workos", "workos@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := WorkosTestSignupSetup(ts, &tokenCount, &userCount, code, workosUserWrongEmail) + defer server.Close() + + u := performAuthorization(ts, "workos", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} diff --git a/auth_v2.169.0/internal/api/external_zoom_test.go b/auth_v2.169.0/internal/api/external_zoom_test.go new file mode 100644 index 0000000..ea3f15c --- /dev/null +++ b/auth_v2.169.0/internal/api/external_zoom_test.go @@ -0,0 +1,167 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + + jwt "github.com/golang-jwt/jwt/v5" +) + +const ( + zoomUser string = `{"id":"zoomUserId","first_name":"John","last_name": "Doe","email": "zoom@example.com","verified": 1,"pic_url":"http://example.com/avatar"}` + zoomUserWrongEmail string = `{"id":"zoomUserId","first_name":"John","last_name": "Doe","email": "other@example.com","verified": 1,"pic_url":"http://example.com/avatar"}` + zoomUserNoEmail string = `{"id":"zoomUserId","first_name":"John","last_name": "Doe","verified": 1,"pic_url":"http://example.com/avatar"}` +) + +func (ts *ExternalTestSuite) TestSignupExternalZoom() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=zoom", nil) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + ts.Equal(ts.Config.External.Zoom.RedirectURI, q.Get("redirect_uri")) + ts.Equal(ts.Config.External.Zoom.ClientID, []string{q.Get("client_id")}) + ts.Equal("code", q.Get("response_type")) + + claims := ExternalProviderClaims{} + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.ParseWithClaims(q.Get("state"), &claims, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + ts.Require().NoError(err) + + ts.Equal("zoom", claims.Provider) + ts.Equal(ts.Config.SiteURL, claims.SiteURL) +} + +func ZoomTestSignupSetup(ts *ExternalTestSuite, tokenCount *int, userCount *int, code string, user string) *httptest.Server { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/token": + *tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Zoom.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, 
`{"access_token":"zoom_token","expires_in":100000}`) + case "/v2/users/me": + *userCount++ + ts.Contains(r.Header, "Authorization") + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, user) + default: + w.WriteHeader(500) + ts.Fail("unknown zoom oauth call %s", r.URL.Path) + } + })) + + ts.Config.External.Zoom.URL = server.URL + + return server +} + +func (ts *ExternalTestSuite) TestSignupExternalZoomAuthorizationCode() { + ts.Config.DisableSignup = false + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUser) + defer server.Close() + + u := performAuthorization(ts, "zoom", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "zoom@example.com", "John Doe", "zoomUserId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestSignupExternalZoomDisableSignupErrorWhenNoUser() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUser) + defer server.Close() + + u := performAuthorization(ts, "zoom", code, "") + + assertAuthorizationFailure(ts, u, "Signups not allowed for this instance", "access_denied", "zoom@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalZoomDisableSignupErrorWhenEmptyEmail() { + ts.Config.DisableSignup = true + + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUserNoEmail) + defer server.Close() + + u := performAuthorization(ts, "zoom", code, "") + + assertAuthorizationFailure(ts, u, "Error getting user email from external provider", "server_error", "zoom@example.com") +} + +func (ts *ExternalTestSuite) TestSignupExternalZoomDisableSignupSuccessWithPrimaryEmail() { + ts.Config.DisableSignup = true + + ts.createUser("zoomUserId", "zoom@example.com", "John Doe", "http://example.com/avatar", "") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUser) + defer server.Close() + + u := performAuthorization(ts, "zoom", code, "") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "zoom@example.com", "John Doe", "zoomUserId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalZoomSuccessWhenMatchingToken() { + ts.createUser("zoomUserId", "zoom@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUser) + defer server.Close() + + u := performAuthorization(ts, "zoom", code, "invite_token") + + assertAuthorizationSuccess(ts, u, tokenCount, userCount, "zoom@example.com", "John Doe", "zoomUserId", "http://example.com/avatar") +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalZoomErrorWhenNoMatchingToken() { + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "zoom", "invite_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalZoomErrorWhenWrongToken() { + ts.createUser("zoomUserId", "zoom@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUser) + defer server.Close() + + w := performAuthorizationRequest(ts, "zoom", 
"wrong_token") + ts.Require().Equal(http.StatusNotFound, w.Code) +} + +func (ts *ExternalTestSuite) TestInviteTokenExternalZoomErrorWhenEmailDoesntMatch() { + ts.createUser("zoomUserId", "zoom@example.com", "", "", "invite_token") + + tokenCount, userCount := 0, 0 + code := "authcode" + server := ZoomTestSignupSetup(ts, &tokenCount, &userCount, code, zoomUserWrongEmail) + defer server.Close() + + u := performAuthorization(ts, "zoom", code, "invite_token") + + assertAuthorizationFailure(ts, u, "Invited email does not match emails from external provider", "invalid_request", "") +} diff --git a/auth_v2.169.0/internal/api/helpers.go b/auth_v2.169.0/internal/api/helpers.go new file mode 100644 index 0000000..8a9f326 --- /dev/null +++ b/auth_v2.169.0/internal/api/helpers.go @@ -0,0 +1,103 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/pkg/errors" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/security" + "github.com/supabase/auth/internal/utilities" +) + +func sendJSON(w http.ResponseWriter, status int, obj interface{}) error { + w.Header().Set("Content-Type", "application/json") + b, err := json.Marshal(obj) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("Error encoding json response: %v", obj)) + } + w.WriteHeader(status) + _, err = w.Write(b) + return err +} + +func isAdmin(u *models.User, config *conf.GlobalConfiguration) bool { + return config.JWT.Aud == u.Aud && u.HasRole(config.JWT.AdminGroupName) +} + +func (a *API) requestAud(ctx context.Context, r *http.Request) string { + config := a.config + // First check for an audience in the header + if aud := r.Header.Get(audHeaderName); aud != "" { + return aud + } + + // Then check the token + claims := getClaims(ctx) + + if claims != nil { + aud, _ := claims.GetAudience() + if len(aud) != 0 && aud[0] != "" { + return aud[0] + } + } + + // Finally, return the default if none of the above methods are successful + return config.JWT.Aud +} + +func isStringInSlice(checkValue string, list []string) bool { + for _, val := range list { + if val == checkValue { + return true + } + } + return false +} + +type RequestParams interface { + AdminUserParams | + CreateSSOProviderParams | + EnrollFactorParams | + GenerateLinkParams | + IdTokenGrantParams | + InviteParams | + OtpParams | + PKCEGrantParams | + PasswordGrantParams | + RecoverParams | + RefreshTokenGrantParams | + ResendConfirmationParams | + SignupParams | + SingleSignOnParams | + SmsParams | + UserUpdateParams | + VerifyFactorParams | + VerifyParams | + adminUserUpdateFactorParams | + adminUserDeleteParams | + security.GotrueRequest | + ChallengeFactorParams | + struct { + Email string `json:"email"` + Phone string `json:"phone"` + } | + struct { + Email string `json:"email"` + } +} + +// retrieveRequestParams is a generic method that unmarshals the request body into the params struct provided +func retrieveRequestParams[A RequestParams](r *http.Request, params *A) error { + body, err := utilities.GetBodyBytes(r) + if err != nil { + return internalServerError("Could not read body into byte slice").WithInternalError(err) + } + if err := json.Unmarshal(body, params); err != nil { + return badRequestError(ErrorCodeBadJSON, "Could not parse request body as JSON: %v", err) + } + return nil +} diff --git a/auth_v2.169.0/internal/api/helpers_test.go b/auth_v2.169.0/internal/api/helpers_test.go new file mode 100644 index 0000000..29070e8 --- /dev/null +++ 
b/auth_v2.169.0/internal/api/helpers_test.go @@ -0,0 +1,151 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +func TestIsValidCodeChallenge(t *testing.T) { + cases := []struct { + challenge string + isValid bool + expectedError error + }{ + { + challenge: "invalid", + isValid: false, + expectedError: badRequestError(ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength), + }, + { + challenge: "codechallengecontainsinvalidcharacterslike@$^&*", + isValid: false, + expectedError: badRequestError(ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes"), + }, + { + challenge: "validchallengevalidchallengevalidchallengevalidchallenge", + isValid: true, + expectedError: nil, + }, + } + + for _, c := range cases { + valid, err := isValidCodeChallenge(c.challenge) + require.Equal(t, c.isValid, valid) + require.Equal(t, c.expectedError, err) + } +} + +func TestIsValidPKCEParams(t *testing.T) { + cases := []struct { + challengeMethod string + challenge string + expected error + }{ + { + challengeMethod: "", + challenge: "", + expected: nil, + }, + { + challengeMethod: "test", + challenge: "testtesttesttesttesttesttesttesttesttesttesttesttesttesttesttest", + expected: nil, + }, + { + challengeMethod: "test", + challenge: "", + expected: badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), + }, + { + challengeMethod: "", + challenge: "test", + expected: badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), + }, + } + + for i, c := range cases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + err := validatePKCEParams(c.challengeMethod, c.challenge) + require.Equal(t, c.expected, err) + }) + } +} + +func TestRequestAud(ts *testing.T) { + mockAPI := API{ + config: &conf.GlobalConfiguration{ + JWT: conf.JWTConfiguration{ + Aud: "authenticated", + Secret: "test-secret", + }, + }, + } + + cases := []struct { + desc string + headers map[string]string + payload map[string]interface{} + expectedAud string + }{ + { + desc: "Valid audience slice", + headers: map[string]string{ + audHeaderName: "my_custom_aud", + }, + payload: map[string]interface{}{ + "aud": "authenticated", + }, + expectedAud: "my_custom_aud", + }, + { + desc: "Valid custom audience", + payload: map[string]interface{}{ + "aud": "my_custom_aud", + }, + expectedAud: "my_custom_aud", + }, + { + desc: "Invalid audience", + payload: map[string]interface{}{ + "aud": "", + }, + expectedAud: mockAPI.config.JWT.Aud, + }, + { + desc: "Missing audience", + payload: map[string]interface{}{ + "sub": "d6044b6e-b0ec-4efe-a055-0d2d6ff1dbd8", + }, + expectedAud: mockAPI.config.JWT.Aud, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func(t *testing.T) { + claims := jwt.MapClaims(c.payload) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, err := token.SignedString([]byte(mockAPI.config.JWT.Secret)) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer: %s", signed)) + for k, v := range c.headers { + req.Header.Set(k, v) + } + + // set the token in the request context for requestAud + ctx, err := mockAPI.parseJWTClaims(signed, req) + require.NoError(t, err) + aud := mockAPI.requestAud(ctx, req) + 
require.Equal(t, c.expectedAud, aud) + }) + } + +} diff --git a/auth_v2.169.0/internal/api/hooks.go b/auth_v2.169.0/internal/api/hooks.go new file mode 100644 index 0000000..2cf99cd --- /dev/null +++ b/auth_v2.169.0/internal/api/hooks.go @@ -0,0 +1,405 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "mime" + "net" + "net/http" + "strings" + "time" + + "github.com/gofrs/uuid" + "github.com/sirupsen/logrus" + standardwebhooks "github.com/standard-webhooks/standard-webhooks/libraries/go" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/hooks" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" +) + +const ( + DefaultHTTPHookTimeout = 5 * time.Second + DefaultHTTPHookRetries = 3 + HTTPHookBackoffDuration = 2 * time.Second + PayloadLimit = 200 * 1024 // 200KB +) + +func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) { + db := a.db.WithContext(ctx) + + request, err := json.Marshal(input) + if err != nil { + panic(err) + } + + var response []byte + invokeHookFunc := func(tx *storage.Connection) error { + // We rely on Postgres timeouts to ensure the function doesn't overrun + if terr := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)).Exec(); terr != nil { + return terr + } + + if terr := tx.RawQuery(fmt.Sprintf("select %s(?);", hookConfig.HookName), request).First(&response); terr != nil { + return terr + } + + // reset the timeout + if terr := tx.RawQuery("set local statement_timeout TO default;").Exec(); terr != nil { + return terr + } + + return nil + } + + if tx != nil { + if err := invokeHookFunc(tx); err != nil { + return nil, err + } + } else { + if err := db.Transaction(invokeHookFunc); err != nil { + return nil, err + } + } + + if err := json.Unmarshal(response, output); err != nil { + return response, err + } + + return response, nil +} + +func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input any) ([]byte, error) { + ctx := r.Context() + client := http.Client{ + Timeout: DefaultHTTPHookTimeout, + } + ctx, cancel := context.WithTimeout(ctx, DefaultHTTPHookTimeout) + defer cancel() + + log := observability.GetLogEntry(r).Entry + requestURL := hookConfig.URI + hookLog := log.WithFields(logrus.Fields{ + "component": "auth_hook", + "url": requestURL, + }) + + inputPayload, err := json.Marshal(input) + if err != nil { + return nil, err + } + for i := 0; i < DefaultHTTPHookRetries; i++ { + if i == 0 { + hookLog.Debugf("invocation attempt: %d", i) + } else { + hookLog.Infof("invocation attempt: %d", i) + } + msgID := uuid.Must(uuid.NewV4()) + currentTime := time.Now() + signatureList, err := generateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewBuffer(inputPayload)) + if err != nil { + panic("Failed to make request object") + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("webhook-id", msgID.String()) + req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix())) + req.Header.Set("webhook-signature", strings.Join(signatureList, ", ")) + // By default, Go Client sets encoding to gzip, which does not carry a content length header. 
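+ // Requesting the identity encoding asks the hook endpoint to send the body uncompressed;
+ // the response is then read through an io.LimitedReader capped at PayloadLimit further below.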
+ req.Header.Set("Accept-Encoding", "identity") + + rsp, err := client.Do(req) + if err != nil && errors.Is(err, context.DeadlineExceeded) { + return nil, unprocessableEntityError(ErrorCodeHookTimeout, fmt.Sprintf("Failed to reach hook within maximum time of %f seconds", DefaultHTTPHookTimeout.Seconds())) + + } else if err != nil { + if terr, ok := err.(net.Error); ok && terr.Timeout() || i < DefaultHTTPHookRetries-1 { + hookLog.Errorf("Request timed out for attempt %d with err %s", i, err) + time.Sleep(HTTPHookBackoffDuration) + continue + } else if i == DefaultHTTPHookRetries-1 { + return nil, unprocessableEntityError(ErrorCodeHookTimeoutAfterRetry, "Failed to reach hook after maximum retries") + } else { + return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err) + } + } + + defer rsp.Body.Close() + + switch rsp.StatusCode { + case http.StatusOK, http.StatusNoContent, http.StatusAccepted: + // Header.Get is case insensitive + contentType := rsp.Header.Get("Content-Type") + if contentType == "" { + return nil, badRequestError(ErrorCodeHookPayloadInvalidContentType, "Invalid Content-Type: Missing Content-Type header") + } + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil { + return nil, badRequestError(ErrorCodeHookPayloadInvalidContentType, fmt.Sprintf("Invalid Content-Type header: %s", err.Error())) + } + if mediaType != "application/json" { + return nil, badRequestError(ErrorCodeHookPayloadInvalidContentType, "Invalid JSON response. Received content-type: "+contentType) + } + if rsp.Body == nil { + return nil, nil + } + limitedReader := io.LimitedReader{R: rsp.Body, N: PayloadLimit} + body, err := io.ReadAll(&limitedReader) + if err != nil { + return nil, err + } + if limitedReader.N <= 0 { + // check if the response body still has excess bytes to be read + if n, _ := rsp.Body.Read(make([]byte, 1)); n > 0 { + return nil, unprocessableEntityError(ErrorCodeHookPayloadOverSizeLimit, fmt.Sprintf("Payload size exceeded size limit of %d bytes", PayloadLimit)) + } + } + return body, nil + case http.StatusTooManyRequests, http.StatusServiceUnavailable: + retryAfterHeader := rsp.Header.Get("retry-after") + // Check for truthy values to allow for flexibility to switch to time duration + if retryAfterHeader != "" { + continue + } + return nil, internalServerError("Service currently unavailable due to hook") + case http.StatusBadRequest: + return nil, internalServerError("Invalid payload sent to hook") + case http.StatusUnauthorized: + return nil, internalServerError("Hook requires authorization token") + default: + return nil, internalServerError("Unexpected status code returned from hook: %d", rsp.StatusCode) + } + } + return nil, nil +} + +// invokePostgresHook invokes the hook code. conn can be nil, in which case a new +// transaction is opened. If calling invokeHook within a transaction, always +// pass the current transaction, as pool-exhaustion deadlocks are very easy to +// trigger. 
+func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, output any) error { + var err error + var response []byte + + switch input.(type) { + case *hooks.SendSMSInput: + hookOutput, ok := output.(*hooks.SendSMSOutput) + if !ok { + panic("output should be *hooks.SendSMSOutput") + } + if response, err = a.runHook(r, conn, a.config.Hook.SendSMS, input, output); err != nil { + return err + } + if err := json.Unmarshal(response, hookOutput); err != nil { + return internalServerError("Error unmarshaling Send SMS output.").WithInternalError(err) + } + if hookOutput.IsError() { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + httpError := &HTTPError{ + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, + } + return httpError.WithInternalError(&hookOutput.HookError) + } + return nil + case *hooks.SendEmailInput: + hookOutput, ok := output.(*hooks.SendEmailOutput) + if !ok { + panic("output should be *hooks.SendEmailOutput") + } + if response, err = a.runHook(r, conn, a.config.Hook.SendEmail, input, output); err != nil { + return err + } + if err := json.Unmarshal(response, hookOutput); err != nil { + return internalServerError("Error unmarshaling Send Email output.").WithInternalError(err) + } + if hookOutput.IsError() { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + + httpError := &HTTPError{ + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, + } + + return httpError.WithInternalError(&hookOutput.HookError) + } + return nil + case *hooks.MFAVerificationAttemptInput: + hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput) + if !ok { + panic("output should be *hooks.MFAVerificationAttemptOutput") + } + if response, err = a.runHook(r, conn, a.config.Hook.MFAVerificationAttempt, input, output); err != nil { + return err + } + if err := json.Unmarshal(response, hookOutput); err != nil { + return internalServerError("Error unmarshaling MFA Verification Attempt output.").WithInternalError(err) + } + if hookOutput.IsError() { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + + httpError := &HTTPError{ + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, + } + + return httpError.WithInternalError(&hookOutput.HookError) + } + return nil + case *hooks.PasswordVerificationAttemptInput: + hookOutput, ok := output.(*hooks.PasswordVerificationAttemptOutput) + if !ok { + panic("output should be *hooks.PasswordVerificationAttemptOutput") + } + + if response, err = a.runHook(r, conn, a.config.Hook.PasswordVerificationAttempt, input, output); err != nil { + return err + } + if err := json.Unmarshal(response, hookOutput); err != nil { + return internalServerError("Error unmarshaling Password Verification Attempt output.").WithInternalError(err) + } + if hookOutput.IsError() { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + + httpError := &HTTPError{ + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, + } + + return httpError.WithInternalError(&hookOutput.HookError) + } + + return nil + case *hooks.CustomAccessTokenInput: + hookOutput, ok := output.(*hooks.CustomAccessTokenOutput) + if !ok { + panic("output should be *hooks.CustomAccessTokenOutput") + } + if response, err = a.runHook(r, conn, a.config.Hook.CustomAccessToken, input, output); err != nil { + return err + } 
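+ // decode the hook output; a hook-declared error is surfaced with its HTTP status,
+ // and the returned claims are validated below before the access token is issued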
+ if err := json.Unmarshal(response, hookOutput); err != nil { + return internalServerError("Error unmarshaling Custom Access Token output.").WithInternalError(err) + } + + if hookOutput.IsError() { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + + httpError := &HTTPError{ + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, + } + + return httpError.WithInternalError(&hookOutput.HookError) + } + if err := validateTokenClaims(hookOutput.Claims); err != nil { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + httpError := &HTTPError{ + HTTPStatus: httpCode, + Message: err.Error(), + } + + return httpError + } + return nil + } + return nil +} + +func (a *API) runHook(r *http.Request, conn *storage.Connection, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) { + ctx := r.Context() + + logEntry := observability.GetLogEntry(r) + hookStart := time.Now() + + var response []byte + var err error + + switch { + case strings.HasPrefix(hookConfig.URI, "http:") || strings.HasPrefix(hookConfig.URI, "https:"): + response, err = a.runHTTPHook(r, hookConfig, input) + case strings.HasPrefix(hookConfig.URI, "pg-functions:"): + response, err = a.runPostgresHook(ctx, conn, hookConfig, input, output) + default: + return nil, fmt.Errorf("unsupported protocol: %q only postgres hooks and HTTPS functions are supported at the moment", hookConfig.URI) + } + + duration := time.Since(hookStart) + + if err != nil { + logEntry.Entry.WithFields(logrus.Fields{ + "action": "run_hook", + "hook": hookConfig.URI, + "success": false, + "duration": duration.Microseconds(), + }).WithError(err).Warn("Hook errored out") + + return nil, internalServerError("Error running hook URI: %v", hookConfig.URI).WithInternalError(err) + } + + logEntry.Entry.WithFields(logrus.Fields{ + "action": "run_hook", + "hook": hookConfig.URI, + "success": true, + "duration": duration.Microseconds(), + }).WithError(err).Info("Hook ran successfully") + + return response, nil +} + +func generateSignatures(secrets []string, msgID uuid.UUID, currentTime time.Time, inputPayload []byte) ([]string, error) { + SymmetricSignaturePrefix := "v1," + // TODO(joel): Handle asymmetric case once library has been upgraded + var signatureList []string + for _, secret := range secrets { + if strings.HasPrefix(secret, SymmetricSignaturePrefix) { + trimmedSecret := strings.TrimPrefix(secret, SymmetricSignaturePrefix) + wh, err := standardwebhooks.NewWebhook(trimmedSecret) + if err != nil { + return nil, err + } + signature, err := wh.Sign(msgID.String(), currentTime, inputPayload) + if err != nil { + return nil, err + } + signatureList = append(signatureList, signature) + } else { + return nil, errors.New("invalid signature format") + } + } + return signatureList, nil +} diff --git a/auth_v2.169.0/internal/api/hooks_test.go b/auth_v2.169.0/internal/api/hooks_test.go new file mode 100644 index 0000000..c78ce5f --- /dev/null +++ b/auth_v2.169.0/internal/api/hooks_test.go @@ -0,0 +1,287 @@ +package api + +import ( + "encoding/json" + "net/http" + "testing" + + "net/http/httptest" + + "github.com/pkg/errors" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/hooks" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + + 
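+ // gock stubs the outbound HTTP requests made by runHTTPHook in these tests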
"gopkg.in/h2non/gock.v1" +) + +var handleApiRequest func(*http.Request) (*http.Response, error) + +type HooksTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration + TestUser *models.User +} + +type MockHttpClient struct { + mock.Mock +} + +func (m *MockHttpClient) Do(req *http.Request) (*http.Response, error) { + return handleApiRequest(req) +} + +func TestHooks(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &HooksTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *HooksTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + u, err := models.NewUser("123456789", "testemail@gmail.com", "securetestpassword", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + ts.TestUser = u +} + +func (ts *HooksTestSuite) TestRunHTTPHook() { + // setup mock requests for hooks + defer gock.OffAll() + + input := hooks.SendSMSInput{ + User: ts.TestUser, + SMS: hooks.SMS{ + OTP: "123456", + }, + } + testURL := "http://localhost:54321/functions/v1/custom-sms-sender" + ts.Config.Hook.SendSMS.URI = testURL + + unsuccessfulResponse := hooks.AuthHookError{ + HTTPCode: http.StatusUnprocessableEntity, + Message: "test error", + } + + testCases := []struct { + description string + expectError bool + mockResponse hooks.AuthHookError + }{ + { + description: "Hook returns success", + expectError: false, + mockResponse: hooks.AuthHookError{}, + }, + { + description: "Hook returns error", + expectError: true, + mockResponse: unsuccessfulResponse, + }, + } + + gock.New(ts.Config.Hook.SendSMS.URI). + Post("/"). + MatchType("json"). + Reply(http.StatusOK). + JSON(hooks.SendSMSOutput{}) + + gock.New(ts.Config.Hook.SendSMS.URI). + Post("/"). + MatchType("json"). + Reply(http.StatusUnprocessableEntity). + JSON(hooks.SendSMSOutput{HookError: unsuccessfulResponse}) + + for _, tc := range testCases { + ts.Run(tc.description, func() { + req, _ := http.NewRequest("POST", ts.Config.Hook.SendSMS.URI, nil) + body, err := ts.API.runHTTPHook(req, ts.Config.Hook.SendSMS, &input) + + if !tc.expectError { + require.NoError(ts.T(), err) + } else { + require.Error(ts.T(), err) + if body != nil { + var output hooks.SendSMSOutput + require.NoError(ts.T(), json.Unmarshal(body, &output)) + require.Equal(ts.T(), unsuccessfulResponse.HTTPCode, output.HookError.HTTPCode) + require.Equal(ts.T(), unsuccessfulResponse.Message, output.HookError.Message) + } + } + }) + } + require.True(ts.T(), gock.IsDone()) +} + +func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() { + defer gock.OffAll() + + input := hooks.SendSMSInput{ + User: ts.TestUser, + SMS: hooks.SMS{ + OTP: "123456", + }, + } + testURL := "http://localhost:54321/functions/v1/custom-sms-sender" + ts.Config.Hook.SendSMS.URI = testURL + + gock.New(testURL). + Post("/"). + MatchType("json"). + Reply(http.StatusTooManyRequests). + SetHeader("retry-after", "true").SetHeader("content-type", "application/json") + + // Simulate an additional response for the retry attempt + gock.New(testURL). + Post("/"). + MatchType("json"). + Reply(http.StatusOK). 
+ JSON(hooks.SendSMSOutput{}).SetHeader("content-type", "application/json") + + // Simulate the original HTTP request which triggered the hook + req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil) + require.NoError(ts.T(), err) + + body, err := ts.API.runHTTPHook(req, ts.Config.Hook.SendSMS, &input) + require.NoError(ts.T(), err) + + var output hooks.SendSMSOutput + err = json.Unmarshal(body, &output) + require.NoError(ts.T(), err, "Unmarshal should not fail") + + // Ensure that all expected HTTP interactions (mocks) have been called + require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called including retry") +} + +func (ts *HooksTestSuite) TestShouldReturnErrorForNonJSONContentType() { + defer gock.OffAll() + + input := hooks.SendSMSInput{ + User: ts.TestUser, + SMS: hooks.SMS{ + OTP: "123456", + }, + } + testURL := "http://localhost:54321/functions/v1/custom-sms-sender" + ts.Config.Hook.SendSMS.URI = testURL + + gock.New(testURL). + Post("/"). + MatchType("json"). + Reply(http.StatusOK). + SetHeader("content-type", "text/plain") + + req, err := http.NewRequest("POST", "http://localhost:9999/otp", nil) + require.NoError(ts.T(), err) + + _, err = ts.API.runHTTPHook(req, ts.Config.Hook.SendSMS, &input) + require.Error(ts.T(), err, "Expected an error due to wrong content type") + require.Contains(ts.T(), err.Error(), "Invalid JSON response.") + + require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called") +} + +func (ts *HooksTestSuite) TestInvokeHookIntegration() { + // We use the Send Email Hook as illustration + defer gock.OffAll() + hookFunctionSQL := ` + create or replace function invoke_test(input jsonb) + returns json as $$ + begin + return input; + end; $$ language plpgsql;` + require.NoError(ts.T(), ts.API.db.RawQuery(hookFunctionSQL).Exec()) + + testHTTPUri := "http://myauthservice.com/signup" + testHTTPSUri := "https://myauthservice.com/signup" + testPGUri := "pg-functions://postgres/auth/invoke_test" + successOutput := map[string]interface{}{} + authEndpoint := "https://app.myapp.com/otp" + gock.New(testHTTPUri). + Post("/"). + MatchType("json"). + Reply(http.StatusOK). + JSON(successOutput).SetHeader("content-type", "application/json") + + gock.New(testHTTPSUri). + Post("/"). + MatchType("json"). + Reply(http.StatusOK). 
+ JSON(successOutput).SetHeader("content-type", "application/json") + + tests := []struct { + description string + conn *storage.Connection + request *http.Request + input any + output any + uri string + expectedError error + }{ + { + description: "HTTP endpoint success", + conn: nil, + request: httptest.NewRequest("POST", authEndpoint, nil), + input: &hooks.SendEmailInput{}, + output: &hooks.SendEmailOutput{}, + uri: testHTTPUri, + }, + { + description: "HTTPS endpoint success", + conn: nil, + request: httptest.NewRequest("POST", authEndpoint, nil), + input: &hooks.SendEmailInput{}, + output: &hooks.SendEmailOutput{}, + uri: testHTTPSUri, + }, + { + description: "PostgreSQL function success", + conn: ts.API.db, + request: httptest.NewRequest("POST", authEndpoint, nil), + input: &hooks.SendEmailInput{}, + output: &hooks.SendEmailOutput{}, + uri: testPGUri, + }, + { + description: "Unsupported protocol error", + conn: nil, + request: httptest.NewRequest("POST", authEndpoint, nil), + input: &hooks.SendEmailInput{}, + output: &hooks.SendEmailOutput{}, + uri: "ftp://example.com/path", + expectedError: errors.New("unsupported protocol: \"ftp://example.com/path\" only postgres hooks and HTTPS functions are supported at the moment"), + }, + } + + var err error + for _, tc := range tests { + // Set up hook config + ts.Config.Hook.SendEmail.Enabled = true + ts.Config.Hook.SendEmail.URI = tc.uri + require.NoError(ts.T(), ts.Config.Hook.SendEmail.PopulateExtensibilityPoint()) + + ts.Run(tc.description, func() { + err = ts.API.invokeHook(tc.conn, tc.request, tc.input, tc.output) + if tc.expectedError != nil { + require.EqualError(ts.T(), err, tc.expectedError.Error()) + } else { + require.NoError(ts.T(), err) + } + }) + + } + // Ensure that all expected HTTP interactions (mocks) have been called + require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called including retry") +} diff --git a/auth_v2.169.0/internal/api/identity.go b/auth_v2.169.0/internal/api/identity.go new file mode 100644 index 0000000..53cef86 --- /dev/null +++ b/auth_v2.169.0/internal/api/identity.go @@ -0,0 +1,155 @@ +package api + +import ( + "context" + "net/http" + + "github.com/fatih/structs" + "github.com/go-chi/chi/v5" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + claims := getClaims(ctx) + if claims == nil { + return internalServerError("Could not read claims") + } + + identityID, err := uuid.FromString(chi.URLParam(r, "identity_id")) + if err != nil { + return notFoundError(ErrorCodeValidationFailed, "identity_id must be an UUID") + } + + aud := a.requestAud(ctx, r) + audienceFromClaims, _ := claims.GetAudience() + if len(audienceFromClaims) == 0 || aud != audienceFromClaims[0] { + return forbiddenError(ErrorCodeUnexpectedAudience, "Token audience doesn't match request audience") + } + + user := getUser(ctx) + if len(user.Identities) <= 1 { + return unprocessableEntityError(ErrorCodeSingleIdentityNotDeletable, "User must have at least 1 identity after unlinking") + } + var identityToBeDeleted *models.Identity + for i := range user.Identities { + identity := user.Identities[i] + if identity.ID == identityID { + identityToBeDeleted = &identity + break + } + } + if identityToBeDeleted == nil { + return unprocessableEntityError(ErrorCodeIdentityNotFound, "Identity doesn't exist") + } + 
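+ // unlink inside a single transaction: record an audit log entry, delete the identity,
+ // then reconcile the user's phone/email and provider metadata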
+ err = a.db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, user, models.IdentityUnlinkAction, "", map[string]interface{}{ + "identity_id": identityToBeDeleted.ID, + "provider": identityToBeDeleted.Provider, + "provider_id": identityToBeDeleted.ProviderID, + }); terr != nil { + return internalServerError("Error recording audit log entry").WithInternalError(terr) + } + if terr := tx.Destroy(identityToBeDeleted); terr != nil { + return internalServerError("Database error deleting identity").WithInternalError(terr) + } + + switch identityToBeDeleted.Provider { + case "phone": + user.PhoneConfirmedAt = nil + if terr := user.SetPhone(tx, ""); terr != nil { + return internalServerError("Database error updating user phone").WithInternalError(terr) + } + if terr := tx.UpdateOnly(user, "phone_confirmed_at"); terr != nil { + return internalServerError("Database error updating user phone").WithInternalError(terr) + } + default: + if terr := user.UpdateUserEmailFromIdentities(tx); terr != nil { + if models.IsUniqueConstraintViolatedError(terr) { + return unprocessableEntityError(ErrorCodeEmailConflictIdentityNotDeletable, "Unable to unlink identity due to email conflict").WithInternalError(terr) + } + return internalServerError("Database error updating user email").WithInternalError(terr) + } + } + if terr := user.UpdateAppMetaDataProviders(tx); terr != nil { + return internalServerError("Database error updating user providers").WithInternalError(terr) + } + return nil + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, map[string]interface{}{}) +} + +func (a *API) LinkIdentity(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + user := getUser(ctx) + rurl, err := a.GetExternalProviderRedirectURL(w, r, user) + if err != nil { + return err + } + skipHTTPRedirect := r.URL.Query().Get("skip_http_redirect") == "true" + if skipHTTPRedirect { + return sendJSON(w, http.StatusOK, map[string]interface{}{ + "url": rurl, + }) + } + http.Redirect(w, r, rurl, http.StatusFound) + return nil +} + +func (a *API) linkIdentityToUser(r *http.Request, ctx context.Context, tx *storage.Connection, userData *provider.UserProvidedData, providerType string) (*models.User, error) { + targetUser := getTargetUser(ctx) + identity, terr := models.FindIdentityByIdAndProvider(tx, userData.Metadata.Subject, providerType) + if terr != nil { + if !models.IsNotFoundError(terr) { + return nil, internalServerError("Database error finding identity for linking").WithInternalError(terr) + } + } + if identity != nil { + if identity.UserID == targetUser.ID { + return nil, unprocessableEntityError(ErrorCodeIdentityAlreadyExists, "Identity is already linked") + } + return nil, unprocessableEntityError(ErrorCodeIdentityAlreadyExists, "Identity is already linked to another user") + } + if _, terr := a.createNewIdentity(tx, targetUser, providerType, structs.Map(userData.Metadata)); terr != nil { + return nil, terr + } + + if targetUser.GetEmail() == "" { + if terr := targetUser.UpdateUserEmailFromIdentities(tx); terr != nil { + if models.IsUniqueConstraintViolatedError(terr) { + return nil, badRequestError(ErrorCodeEmailExists, DuplicateEmailMsg) + } + return nil, terr + } + if !userData.Metadata.EmailVerified { + if terr := a.sendConfirmation(r, tx, targetUser, models.ImplicitFlow); terr != nil { + return nil, terr + } + return nil, storage.NewCommitWithError(unprocessableEntityError(ErrorCodeEmailNotConfirmed, "Unverified email with %v. 
A confirmation email has been sent to your %v email", providerType, providerType)) + } + if terr := targetUser.Confirm(tx); terr != nil { + return nil, terr + } + + if targetUser.IsAnonymous { + targetUser.IsAnonymous = false + if terr := tx.UpdateOnly(targetUser, "is_anonymous"); terr != nil { + return nil, terr + } + } + } + + if terr := targetUser.UpdateAppMetaDataProviders(tx); terr != nil { + return nil, terr + } + return targetUser, nil +} diff --git a/auth_v2.169.0/internal/api/identity_test.go b/auth_v2.169.0/internal/api/identity_test.go new file mode 100644 index 0000000..999559e --- /dev/null +++ b/auth_v2.169.0/internal/api/identity_test.go @@ -0,0 +1,227 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type IdentityTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestIdentity(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + ts := &IdentityTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + suite.Run(t, ts) +} + +func (ts *IdentityTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create user + u, err := models.NewUser("", "one@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + require.NoError(ts.T(), u.Confirm(ts.API.db)) + + // Create identity + i, err := models.NewIdentity(u, "email", map[string]interface{}{ + "sub": u.ID.String(), + "email": u.GetEmail(), + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(i)) + + // Create user with 2 identities + u, err = models.NewUser("123456789", "two@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + require.NoError(ts.T(), u.Confirm(ts.API.db)) + require.NoError(ts.T(), u.ConfirmPhone(ts.API.db)) + + i, err = models.NewIdentity(u, "email", map[string]interface{}{ + "sub": u.ID.String(), + "email": u.GetEmail(), + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(i)) + + i2, err := models.NewIdentity(u, "phone", map[string]interface{}{ + "sub": u.ID.String(), + "phone": u.GetPhone(), + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(i2)) +} + +func (ts *IdentityTestSuite) TestLinkIdentityToUser() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "one@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + ctx := withTargetUser(context.Background(), u) + + // link a valid identity + testValidUserData := &provider.UserProvidedData{ + Metadata: &provider.Claims{ + Subject: "test_subject", + }, + } + // request is just used as a placeholder in the function + r := httptest.NewRequest(http.MethodGet, "/identities", nil) + u, err = ts.API.linkIdentityToUser(r, ctx, ts.API.db, testValidUserData, "test") + require.NoError(ts.T(), err) + + // load associated identities for the user + ts.API.db.Load(u, "Identities") + require.Len(ts.T(), u.Identities, 2) + require.Equal(ts.T(), u.AppMetaData["provider"], "email") + require.Equal(ts.T(), 
u.AppMetaData["providers"], []string{"email", "test"}) + + // link an already existing identity + testExistingUserData := &provider.UserProvidedData{ + Metadata: &provider.Claims{ + Subject: u.ID.String(), + }, + } + u, err = ts.API.linkIdentityToUser(r, ctx, ts.API.db, testExistingUserData, "email") + require.ErrorIs(ts.T(), err, unprocessableEntityError(ErrorCodeIdentityAlreadyExists, "Identity is already linked")) + require.Nil(ts.T(), u) +} + +func (ts *IdentityTestSuite) TestUnlinkIdentityError() { + ts.Config.Security.ManualLinkingEnabled = true + userWithOneIdentity, err := models.FindUserByEmailAndAudience(ts.API.db, "one@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + userWithTwoIdentities, err := models.FindUserByEmailAndAudience(ts.API.db, "two@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + cases := []struct { + desc string + user *models.User + identityId uuid.UUID + expectedError *HTTPError + }{ + { + desc: "User must have at least 1 identity after unlinking", + user: userWithOneIdentity, + identityId: userWithOneIdentity.Identities[0].ID, + expectedError: unprocessableEntityError(ErrorCodeSingleIdentityNotDeletable, "User must have at least 1 identity after unlinking"), + }, + { + desc: "Identity doesn't exist", + user: userWithTwoIdentities, + identityId: uuid.Must(uuid.NewV4()), + expectedError: unprocessableEntityError(ErrorCodeIdentityNotFound, "Identity doesn't exist"), + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + token := ts.generateAccessTokenAndSession(c.user) + req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/user/identities/%s", c.identityId), nil) + require.NoError(ts.T(), err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expectedError.HTTPStatus, w.Code) + + var data HTTPError + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), c.expectedError.Message, data.Message) + }) + } +} + +func (ts *IdentityTestSuite) TestUnlinkIdentity() { + ts.Config.Security.ManualLinkingEnabled = true + + // we want to test 2 cases here: unlinking a phone identity and email identity from a user + cases := []struct { + desc string + // the provider to be unlinked + provider string + // the remaining provider that should be linked to the user + providerRemaining string + }{ + { + desc: "Unlink phone identity successfully", + provider: "phone", + providerRemaining: "email", + }, + { + desc: "Unlink email identity successfully", + provider: "email", + providerRemaining: "phone", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // teardown and reset the state of the db to prevent running into errors + ts.SetupTest() + u, err := models.FindUserByEmailAndAudience(ts.API.db, "two@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + identity, err := models.FindIdentityByIdAndProvider(ts.API.db, u.ID.String(), c.provider) + require.NoError(ts.T(), err) + + token := ts.generateAccessTokenAndSession(u) + req, err := http.NewRequest(http.MethodDelete, fmt.Sprintf("/user/identities/%s", identity.ID), nil) + require.NoError(ts.T(), err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // sanity checks + u, err = models.FindUserByID(ts.API.db, u.ID) + require.NoError(ts.T(), err) + require.Len(ts.T(), u.Identities, 
1) + require.Equal(ts.T(), u.Identities[0].Provider, c.providerRemaining) + + // conditional checks depending on the provider that was unlinked + switch c.provider { + case "phone": + require.Equal(ts.T(), "", u.GetPhone()) + require.Nil(ts.T(), u.PhoneConfirmedAt) + case "email": + require.Equal(ts.T(), "", u.GetEmail()) + require.Nil(ts.T(), u.EmailConfirmedAt) + } + + // user still has a phone / email identity linked so it should not be unconfirmed + require.NotNil(ts.T(), u.ConfirmedAt) + }) + } + +} + +func (ts *IdentityTestSuite) generateAccessTokenAndSession(u *models.User) string { + s, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(s)) + + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, u, &s.ID, models.PasswordGrant) + require.NoError(ts.T(), err) + return token + +} diff --git a/auth_v2.169.0/internal/api/invite.go b/auth_v2.169.0/internal/api/invite.go new file mode 100644 index 0000000..f0260dd --- /dev/null +++ b/auth_v2.169.0/internal/api/invite.go @@ -0,0 +1,92 @@ +package api + +import ( + "net/http" + + "github.com/fatih/structs" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// InviteParams are the parameters the Signup endpoint accepts +type InviteParams struct { + Email string `json:"email"` + Data map[string]interface{} `json:"data"` +} + +// Invite is the endpoint for inviting a new user +func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + adminUser := getAdminUser(ctx) + params := &InviteParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + var err error + params.Email, err = a.validateEmail(params.Email) + if err != nil { + return err + } + + aud := a.requestAud(ctx, r) + user, err := models.FindUserByEmailAndAudience(db, params.Email, aud) + if err != nil && !models.IsNotFoundError(err) { + return internalServerError("Database error finding user").WithInternalError(err) + } + + err = db.Transaction(func(tx *storage.Connection) error { + if user != nil { + if user.IsConfirmed() { + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) + } + } else { + signupParams := SignupParams{ + Email: params.Email, + Data: params.Data, + Aud: aud, + Provider: "email", + } + + // because params above sets no password, this method + // is not computationally hard so it can be used within + // a database transaction + user, err = signupParams.ToUserModel(false /* <- isSSOUser */) + if err != nil { + return err + } + + user, err = a.signupNewUser(tx, user) + if err != nil { + return err + } + identity, err := a.createNewIdentity(tx, user, "email", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Email: user.GetEmail(), + })) + if err != nil { + return err + } + user.Identities = []models.Identity{*identity} + } + + if terr := models.NewAuditLogEntry(r, tx, adminUser, models.UserInvitedAction, "", map[string]interface{}{ + "user_id": user.ID, + "user_email": user.Email, + }); terr != nil { + return terr + } + + if err := a.sendInvite(r, tx, user); err != nil { + return err + } + return nil + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, user) +} diff --git a/auth_v2.169.0/internal/api/invite_test.go b/auth_v2.169.0/internal/api/invite_test.go new file mode 100644 
index 0000000..ff0baca --- /dev/null +++ b/auth_v2.169.0/internal/api/invite_test.go @@ -0,0 +1,404 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" +) + +type InviteTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration + + token string +} + +func TestInvite(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &InviteTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *InviteTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Setup response recorder with super admin privileges + ts.token = ts.makeSuperAdmin("") +} + +func (ts *InviteTestSuite) makeSuperAdmin(email string) string { + // Cleanup existing user, if they already exist + if u, _ := models.FindUserByEmailAndAudience(ts.API.db, email, ts.Config.JWT.Aud); u != nil { + require.NoError(ts.T(), ts.API.db.Destroy(u), "Error deleting user") + } + + u, err := models.NewUser("123456789", email, "test", ts.Config.JWT.Aud, map[string]interface{}{"full_name": "Test User"}) + require.NoError(ts.T(), err, "Error making new user") + require.NoError(ts.T(), ts.API.db.Create(u)) + + u.Role = "supabase_admin" + + var token string + + session, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(session)) + + req := httptest.NewRequest(http.MethodPost, "/invite", nil) + token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.Invite) + + require.NoError(ts.T(), err, "Error generating access token") + + p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name})) + _, err = p.Parse(token, func(token *jwt.Token) (interface{}, error) { + return []byte(ts.Config.JWT.Secret), nil + }) + require.NoError(ts.T(), err, "Error parsing token") + + return token +} + +func (ts *InviteTestSuite) TestInvite() { + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "data": map[string]interface{}{ + "a": 1, + }, + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/invite", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + // Setup response recorder + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) +} + +func (ts *InviteTestSuite) TestInviteAfterSignupShouldNotReturnSensitiveFields() { + // To allow us to send signup and invite request in succession + ts.Config.SMTP.MaxFrequency = 5 + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "data": map[string]interface{}{ + "a": 1, + }, + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/invite", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + // Setup response recorder + w := httptest.NewRecorder() + + 
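+ // send the invite first; the signup below reuses the same email and its response
+ // must not leak the user's identities or metadata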
ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "password": "test123", + "data": map[string]interface{}{ + "a": 1, + }, + })) + + // Setup request + req = httptest.NewRequest(http.MethodPost, "/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + x := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(x, req) + + require.Equal(ts.T(), http.StatusOK, x.Code) + + data := models.User{} + require.NoError(ts.T(), json.NewDecoder(x.Body).Decode(&data)) + // Sensitive fields + require.Equal(ts.T(), 0, len(data.Identities)) + require.Equal(ts.T(), 0, len(data.UserMetaData)) +} + +func (ts *InviteTestSuite) TestInvite_WithoutAccess() { + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "data": map[string]interface{}{ + "a": 1, + }, + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/invite", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusUnauthorized, w.Code) // 401 OK because the invite request above has no Authorization header +} + +func (ts *InviteTestSuite) TestVerifyInvite() { + cases := []struct { + desc string + email string + requestBody map[string]interface{} + expected int + }{ + { + "Verify invite with password", + "test@example.com", + map[string]interface{}{ + "email": "test@example.com", + "type": "invite", + "token": "asdf", + "password": "testing", + }, + http.StatusOK, + }, + { + "Verify invite with no password", + "test1@example.com", + map[string]interface{}{ + "email": "test1@example.com", + "type": "invite", + "token": "asdf", + }, + http.StatusOK, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + user, err := models.NewUser("", c.email, "", ts.Config.JWT.Aud, nil) + now := time.Now() + user.InvitedAt = &now + user.ConfirmationSentAt = &now + user.EncryptedPassword = nil + user.ConfirmationToken = crypto.GenerateTokenHash(c.email, c.requestBody["token"].(string)) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(user)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, user.ID, user.GetEmail(), user.ConfirmationToken, models.ConfirmationToken)) + + // Find test user + _, err = models.FindUserByEmailAndAudience(ts.API.db, c.email, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.requestBody)) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), c.expected, w.Code, w.Body.String()) + }) + } +} + +func (ts *InviteTestSuite) TestInviteExternalGitlab() { + tokenCount, userCount := 0, 0 + code := "authcode" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/token": + tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Gitlab.RedirectURI, 
r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"gitlab_token","expires_in":100000}`) + case "/api/v4/user": + userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"name":"Gitlab Test","email":"gitlab@example.com","avatar_url":"http://example.com/avatar","confirmed_at": "2020-01-01T00:00:00.000Z"}`) + case "/api/v4/user/emails": + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `[]`) + default: + w.WriteHeader(http.StatusInternalServerError) + ts.Fail("unknown gitlab oauth call %s", r.URL.Path) + } + })) + defer server.Close() + ts.Config.External.Gitlab.URL = server.URL + + // invite user + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(InviteParams{ + Email: "gitlab@example.com", + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost/invite", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusOK, w.Code) + + // Find test user + user, err := models.FindUserByEmailAndAudience(ts.API.db, "gitlab@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // get redirect url w/ state + req = httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=gitlab&invite_token="+user.ConfirmationToken, nil) + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + state := q.Get("state") + + // auth server callback + testURL, err := url.Parse("http://localhost/callback") + ts.Require().NoError(err) + v := testURL.Query() + v.Set("code", code) + v.Set("state", state) + testURL.RawQuery = v.Encode() + req = httptest.NewRequest(http.MethodGet, testURL.String(), nil) + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err = url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + + // ensure redirect has #access_token=... 
+ v, err = url.ParseQuery(u.Fragment) + ts.Require().NoError(err) + ts.Require().Empty(v.Get("error_description")) + ts.Require().Empty(v.Get("error")) + + ts.NotEmpty(v.Get("access_token")) + ts.NotEmpty(v.Get("refresh_token")) + ts.NotEmpty(v.Get("expires_in")) + ts.Equal("bearer", v.Get("token_type")) + + ts.Equal(1, tokenCount) + ts.Equal(1, userCount) + + // ensure user has been created with metadata + user, err = models.FindUserByEmailAndAudience(ts.API.db, "gitlab@example.com", ts.Config.JWT.Aud) + ts.Require().NoError(err) + ts.Equal("Gitlab Test", user.UserMetaData["full_name"]) + ts.Equal("http://example.com/avatar", user.UserMetaData["avatar_url"]) + ts.Equal("gitlab", user.AppMetaData["provider"]) + ts.Equal([]interface{}{"gitlab"}, user.AppMetaData["providers"]) +} + +func (ts *InviteTestSuite) TestInviteExternalGitlab_MismatchedEmails() { + tokenCount, userCount := 0, 0 + code := "authcode" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/oauth/token": + tokenCount++ + ts.Equal(code, r.FormValue("code")) + ts.Equal("authorization_code", r.FormValue("grant_type")) + ts.Equal(ts.Config.External.Gitlab.RedirectURI, r.FormValue("redirect_uri")) + + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"gitlab_token","expires_in":100000}`) + case "/api/v4/user": + userCount++ + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `{"name":"Gitlab Test","email":"gitlab+mismatch@example.com","avatar_url":"http://example.com/avatar","confirmed_at": "2020-01-01T00:00:00.000Z"}`) + case "/api/v4/user/emails": + w.Header().Add("Content-Type", "application/json") + fmt.Fprint(w, `[]`) + default: + w.WriteHeader(500) + ts.Fail("unknown gitlab oauth call %s", r.URL.Path) + } + })) + defer server.Close() + ts.Config.External.Gitlab.URL = server.URL + + // invite user + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(InviteParams{ + Email: "gitlab@example.com", + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost/invite", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusOK, w.Code) + + // Find test user + user, err := models.FindUserByEmailAndAudience(ts.API.db, "gitlab@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // get redirect url w/ state + req = httptest.NewRequest(http.MethodGet, "http://localhost/authorize?provider=gitlab&invite_token="+user.ConfirmationToken, nil) + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err := url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + q := u.Query() + state := q.Get("state") + + // auth server callback + testURL, err := url.Parse("http://localhost/callback") + ts.Require().NoError(err) + v := testURL.Query() + v.Set("code", code) + v.Set("state", state) + testURL.RawQuery = v.Encode() + req = httptest.NewRequest(http.MethodGet, testURL.String(), nil) + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + ts.Require().Equal(http.StatusFound, w.Code) + u, err = url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + + // ensure redirect has #access_token=... 
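+ // with mismatched emails the callback redirects with the error in the query string, so parse u.RawQuery here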
+ v, err = url.ParseQuery(u.RawQuery) + ts.Require().NoError(err, u.RawQuery) + ts.Require().NotEmpty(v.Get("error_description")) + ts.Require().Equal("invalid_request", v.Get("error")) +} diff --git a/auth_v2.169.0/internal/api/jwks.go b/auth_v2.169.0/internal/api/jwks.go new file mode 100644 index 0000000..b8304d2 --- /dev/null +++ b/auth_v2.169.0/internal/api/jwks.go @@ -0,0 +1,61 @@ +package api + +import ( + "net/http" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/lestrrat-go/jwx/v2/jwa" + jwk "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/supabase/auth/internal/conf" +) + +type JwksResponse struct { + Keys []jwk.Key `json:"keys"` +} + +func (a *API) Jwks(w http.ResponseWriter, r *http.Request) error { + config := a.config + resp := JwksResponse{ + Keys: []jwk.Key{}, + } + + for _, key := range config.JWT.Keys { + // don't expose hmac jwk in endpoint + if key.PublicKey == nil || key.PublicKey.KeyType() == jwa.OctetSeq { + continue + } + resp.Keys = append(resp.Keys, key.PublicKey) + } + + w.Header().Set("Cache-Control", "public, max-age=600") + return sendJSON(w, http.StatusOK, resp) +} + +func signJwt(config *conf.JWTConfiguration, claims jwt.Claims) (string, error) { + signingJwk, err := conf.GetSigningJwk(config) + if err != nil { + return "", err + } + signingMethod := conf.GetSigningAlg(signingJwk) + token := jwt.NewWithClaims(signingMethod, claims) + if token.Header == nil { + token.Header = make(map[string]interface{}) + } + + if _, ok := token.Header["kid"]; !ok { + if kid := signingJwk.KeyID(); kid != "" { + token.Header["kid"] = kid + } + } + // this serializes the aud claim to a string + jwt.MarshalSingleStringAsArray = false + signingKey, err := conf.GetSigningKey(signingJwk) + if err != nil { + return "", err + } + signed, err := token.SignedString(signingKey) + if err != nil { + return "", err + } + return signed, nil +} diff --git a/auth_v2.169.0/internal/api/jwks_test.go b/auth_v2.169.0/internal/api/jwks_test.go new file mode 100644 index 0000000..786d343 --- /dev/null +++ b/auth_v2.169.0/internal/api/jwks_test.go @@ -0,0 +1,79 @@ +package api + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +func TestJwks(t *testing.T) { + // generate RSA key pair for testing + rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + rsaJwkPrivate, err := jwk.FromRaw(rsaPrivateKey) + require.NoError(t, err) + rsaJwkPublic, err := rsaJwkPrivate.PublicKey() + require.NoError(t, err) + kid := rsaJwkPublic.KeyID() + + cases := []struct { + desc string + config conf.JWTConfiguration + expectedLen int + }{ + { + desc: "hmac key should not be returned", + config: conf.JWTConfiguration{ + Aud: "authenticated", + Secret: "test-secret", + }, + expectedLen: 0, + }, + { + desc: "rsa public key returned", + config: conf.JWTConfiguration{ + Aud: "authenticated", + Secret: "test-secret", + Keys: conf.JwtKeysDecoder{ + kid: conf.JwkInfo{ + PublicKey: rsaJwkPublic, + PrivateKey: rsaJwkPrivate, + }, + }, + }, + expectedLen: 1, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + mockAPI, _, err := setupAPIForTest() + require.NoError(t, err) + mockAPI.config.JWT = c.config + + req := httptest.NewRequest(http.MethodGet, "/.well-known/jwks.json", nil) + w := httptest.NewRecorder() + mockAPI.handler.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + 
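+ // decode the JWKS document and confirm only the expected public keys are exposed; HMAC secrets must never appear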
+ var data map[string]interface{} + require.NoError(t, json.NewDecoder(w.Body).Decode(&data)) + require.Len(t, data["keys"], c.expectedLen) + + for _, key := range data["keys"].([]interface{}) { + bytes, err := json.Marshal(key) + require.NoError(t, err) + actualKey, err := jwk.ParseKey(bytes) + require.NoError(t, err) + require.Equal(t, c.config.Keys[kid].PublicKey, actualKey) + } + }) + } +} diff --git a/auth_v2.169.0/internal/api/logout.go b/auth_v2.169.0/internal/api/logout.go new file mode 100644 index 0000000..8afec6a --- /dev/null +++ b/auth_v2.169.0/internal/api/logout.go @@ -0,0 +1,73 @@ +package api + +import ( + "fmt" + "net/http" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +type LogoutBehavior string + +const ( + LogoutGlobal LogoutBehavior = "global" + LogoutLocal LogoutBehavior = "local" + LogoutOthers LogoutBehavior = "others" +) + +// Logout is the endpoint for logging out a user and thereby revoking any refresh tokens +func (a *API) Logout(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + scope := LogoutGlobal + + if r.URL.Query() != nil { + switch r.URL.Query().Get("scope") { + case "", "global": + scope = LogoutGlobal + + case "local": + scope = LogoutLocal + + case "others": + scope = LogoutOthers + + default: + return badRequestError(ErrorCodeValidationFailed, fmt.Sprintf("Unsupported logout scope %q", r.URL.Query().Get("scope"))) + } + } + + s := getSession(ctx) + u := getUser(ctx) + + err := db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, u, models.LogoutAction, "", nil); terr != nil { + return terr + } + + if s == nil { + logrus.Infof("user has an empty session_id claim: %s", u.ID) + } else { + //exhaustive:ignore Default case is handled below. 
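+ // "local" revokes only the current session, "others" revokes every session except the
+ // current one, and the default path below logs the user out of all sessions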
+ switch scope { + case LogoutLocal: + return models.LogoutSession(tx, s.ID) + + case LogoutOthers: + return models.LogoutAllExceptMe(tx, s.ID, u.ID) + } + } + + // default mode, log out everywhere + return models.Logout(tx, u.ID) + }) + if err != nil { + return internalServerError("Error logging out user").WithInternalError(err) + } + + w.WriteHeader(http.StatusNoContent) + + return nil +} diff --git a/auth_v2.169.0/internal/api/logout_test.go b/auth_v2.169.0/internal/api/logout_test.go new file mode 100644 index 0000000..b1a0fdb --- /dev/null +++ b/auth_v2.169.0/internal/api/logout_test.go @@ -0,0 +1,75 @@ +package api + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type LogoutTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration + token string +} + +func TestLogout(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &LogoutTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *LogoutTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + + // generate access token to use for logout + var t string + s, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(s)) + + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + t, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &s.ID, models.PasswordGrant) + require.NoError(ts.T(), err) + ts.token = t +} + +func (ts *LogoutTestSuite) TestLogoutSuccess() { + for _, scope := range []string{"", "global", "local", "others"} { + ts.SetupTest() + + reqURL, err := url.ParseRequestURI("http://localhost/logout") + require.NoError(ts.T(), err) + + if scope != "" { + query := reqURL.Query() + query.Set("scope", scope) + reqURL.RawQuery = query.Encode() + } + + req := httptest.NewRequest(http.MethodPost, reqURL.String(), nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusNoContent, w.Code) + } +} diff --git a/auth_v2.169.0/internal/api/magic_link.go b/auth_v2.169.0/internal/api/magic_link.go new file mode 100644 index 0000000..57b0a7d --- /dev/null +++ b/auth_v2.169.0/internal/api/magic_link.go @@ -0,0 +1,164 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// MagicLinkParams holds the parameters for a magic link request +type MagicLinkParams struct { + Email string `json:"email"` + Data map[string]interface{} `json:"data"` + CodeChallengeMethod string `json:"code_challenge_method"` + CodeChallenge string `json:"code_challenge"` +} + +func (p *MagicLinkParams) Validate(a *API) error { + if p.Email == "" { + return unprocessableEntityError(ErrorCodeValidationFailed, "Password recovery requires an email") + } + var err error + p.Email, err = a.validateEmail(p.Email) + if err != nil { + return err + } + if err := 
validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { + return err + } + return nil +} + +// MagicLink sends a recovery email +func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + + if !config.External.Email.Enabled { + return unprocessableEntityError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") + } + + if !config.External.Email.MagicLinkEnabled { + return unprocessableEntityError(ErrorCodeEmailProviderDisabled, "Login with magic link is disabled") + } + + params := &MagicLinkParams{} + jsonDecoder := json.NewDecoder(r.Body) + err := jsonDecoder.Decode(params) + if err != nil { + return badRequestError(ErrorCodeBadJSON, "Could not read verification params: %v", err).WithInternalError(err) + } + + if err := params.Validate(a); err != nil { + return err + } + + if params.Data == nil { + params.Data = make(map[string]interface{}) + } + + flowType := getFlowFromChallenge(params.CodeChallenge) + + var isNewUser bool + aud := a.requestAud(ctx, r) + user, err := models.FindUserByEmailAndAudience(db, params.Email, aud) + if err != nil { + if models.IsNotFoundError(err) { + isNewUser = true + } else { + return internalServerError("Database error finding user").WithInternalError(err) + } + } + if user != nil { + isNewUser = !user.IsConfirmed() + } + if isNewUser { + // User either doesn't exist or hasn't completed the signup process. + // Sign them up with temporary password. + password := crypto.GeneratePassword(config.Password.RequiredCharacters, 33) + + signUpParams := &SignupParams{ + Email: params.Email, + Password: password, + Data: params.Data, + CodeChallengeMethod: params.CodeChallengeMethod, + CodeChallenge: params.CodeChallenge, + } + newBodyContent, err := json.Marshal(signUpParams) + if err != nil { + // SignupParams must always be marshallable + panic(fmt.Errorf("failed to marshal SignupParams: %w", err)) + } + r.Body = io.NopCloser(strings.NewReader(string(newBodyContent))) + r.ContentLength = int64(len(string(newBodyContent))) + + fakeResponse := &responseStub{} + if config.Mailer.Autoconfirm { + // signups are autoconfirmed, send magic link after signup + if err := a.Signup(fakeResponse, r); err != nil { + return err + } + newBodyContent := &SignupParams{ + Email: params.Email, + Data: params.Data, + CodeChallengeMethod: params.CodeChallengeMethod, + CodeChallenge: params.CodeChallenge, + } + metadata, err := json.Marshal(newBodyContent) + if err != nil { + // SignupParams must always be marshallable + panic(fmt.Errorf("failed to marshal SignupParams: %w", err)) + } + r.Body = io.NopCloser(bytes.NewReader(metadata)) + return a.MagicLink(w, r) + } + // otherwise confirmation email already contains 'magic link' + if err := a.Signup(fakeResponse, r); err != nil { + return err + } + + return sendJSON(w, http.StatusOK, make(map[string]string)) + } + + if isPKCEFlow(flowType) { + if _, err = generateFlowState(a.db, models.MagicLink.String(), models.MagicLink, params.CodeChallengeMethod, params.CodeChallenge, &user.ID); err != nil { + return err + } + } + + err = db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil { + return terr + } + return a.sendMagicLink(r, tx, user, flowType) + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, make(map[string]string)) +} + +// responseStub only implement http responsewriter for ignoring +// 
incoming data from methods where it passed +type responseStub struct { +} + +func (rw *responseStub) Header() http.Header { + return http.Header{} +} + +func (rw *responseStub) Write(data []byte) (int, error) { + return 1, nil +} + +func (rw *responseStub) WriteHeader(statusCode int) { +} diff --git a/auth_v2.169.0/internal/api/mail.go b/auth_v2.169.0/internal/api/mail.go new file mode 100644 index 0000000..f2ea69b --- /dev/null +++ b/auth_v2.169.0/internal/api/mail.go @@ -0,0 +1,685 @@ +package api + +import ( + "net/http" + "regexp" + "strings" + "time" + + "github.com/supabase/auth/internal/hooks" + mail "github.com/supabase/auth/internal/mailer" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + + "github.com/badoux/checkmail" + "github.com/fatih/structs" + "github.com/pkg/errors" + "github.com/sethvargo/go-password/password" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +var ( + EmailRateLimitExceeded error = errors.New("email rate limit exceeded") +) + +type GenerateLinkParams struct { + Type string `json:"type"` + Email string `json:"email"` + NewEmail string `json:"new_email"` + Password string `json:"password"` + Data map[string]interface{} `json:"data"` + RedirectTo string `json:"redirect_to"` +} + +type GenerateLinkResponse struct { + models.User + ActionLink string `json:"action_link"` + EmailOtp string `json:"email_otp"` + HashedToken string `json:"hashed_token"` + VerificationType string `json:"verification_type"` + RedirectTo string `json:"redirect_to"` +} + +func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + mailer := a.Mailer() + adminUser := getAdminUser(ctx) + params := &GenerateLinkParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + var err error + params.Email, err = a.validateEmail(params.Email) + if err != nil { + return err + } + referrer := utilities.GetReferrer(r, config) + if utilities.IsRedirectURLValid(config, params.RedirectTo) { + referrer = params.RedirectTo + } + + aud := a.requestAud(ctx, r) + user, err := models.FindUserByEmailAndAudience(db, params.Email, aud) + if err != nil { + if models.IsNotFoundError(err) { + switch params.Type { + case mail.MagicLinkVerification: + params.Type = mail.SignupVerification + params.Password, err = password.Generate(64, 10, 1, false, true) + if err != nil { + // password generation must always succeed + panic(err) + } + case mail.RecoveryVerification, mail.EmailChangeCurrentVerification, mail.EmailChangeNewVerification: + return notFoundError(ErrorCodeUserNotFound, "User with this email not found") + } + } else { + return internalServerError("Database error finding user").WithInternalError(err) + } + } + + var url string + now := time.Now() + otp := crypto.GenerateOtp(config.Mailer.OtpLength) + + hashedToken := crypto.GenerateTokenHash(params.Email, otp) + + var signupUser *models.User + if params.Type == mail.SignupVerification && user == nil { + signupParams := &SignupParams{ + Email: params.Email, + Password: params.Password, + Data: params.Data, + Provider: "email", + Aud: aud, + } + + if err := a.validateSignupParams(ctx, signupParams); err != nil { + return err + } + + signupUser, err = signupParams.ToUserModel(false /* <- isSSOUser */) + if err != nil { + 
return err + } + } + + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + switch params.Type { + case mail.MagicLinkVerification, mail.RecoveryVerification: + if terr = models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil { + return terr + } + user.RecoveryToken = hashedToken + user.RecoverySentAt = &now + terr = tx.UpdateOnly(user, "recovery_token", "recovery_sent_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for recovery") + return terr + } + + terr = models.CreateOneTimeToken(tx, user.ID, user.GetEmail(), user.RecoveryToken, models.RecoveryToken) + if terr != nil { + terr = errors.Wrap(terr, "Database error creating recovery token in admin") + return terr + } + case mail.InviteVerification: + if user != nil { + if user.IsConfirmed() { + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) + } + } else { + signupParams := &SignupParams{ + Email: params.Email, + Data: params.Data, + Provider: "email", + Aud: aud, + } + + // because params above sets no password, this + // method is not computationally hard so it can + // be used within a database transaction + user, terr = signupParams.ToUserModel(false /* <- isSSOUser */) + if terr != nil { + return terr + } + + user, terr = a.signupNewUser(tx, user) + if terr != nil { + return terr + } + identity, terr := a.createNewIdentity(tx, user, "email", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Email: user.GetEmail(), + })) + if terr != nil { + return terr + } + user.Identities = []models.Identity{*identity} + } + if terr = models.NewAuditLogEntry(r, tx, adminUser, models.UserInvitedAction, "", map[string]interface{}{ + "user_id": user.ID, + "user_email": user.Email, + }); terr != nil { + return terr + } + user.ConfirmationToken = hashedToken + user.ConfirmationSentAt = &now + user.InvitedAt = &now + terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at", "invited_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for invite") + return terr + } + terr = models.CreateOneTimeToken(tx, user.ID, user.GetEmail(), user.ConfirmationToken, models.ConfirmationToken) + if terr != nil { + terr = errors.Wrap(terr, "Database error creating confirmation token for invite in admin") + return terr + } + case mail.SignupVerification: + if user != nil { + if user.IsConfirmed() { + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) + } + if err := user.UpdateUserMetaData(tx, params.Data); err != nil { + return internalServerError("Database error updating user").WithInternalError(err) + } + } else { + // you should never use SignupParams with + // password here to generate a new user, use + // signupUser which is a model generated from + // SignupParams above + user, terr = a.signupNewUser(tx, signupUser) + if terr != nil { + return terr + } + identity, terr := a.createNewIdentity(tx, user, "email", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Email: user.GetEmail(), + })) + if terr != nil { + return terr + } + user.Identities = []models.Identity{*identity} + } + user.ConfirmationToken = hashedToken + user.ConfirmationSentAt = &now + terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for confirmation") + return terr + } + terr = models.CreateOneTimeToken(tx, user.ID, user.GetEmail(), user.ConfirmationToken, models.ConfirmationToken) + if terr != nil 
{ + terr = errors.Wrap(terr, "Database error creating confirmation token for signup in admin") + return terr + } + case mail.EmailChangeCurrentVerification, mail.EmailChangeNewVerification: + if !config.Mailer.SecureEmailChangeEnabled && params.Type == "email_change_current" { + return badRequestError(ErrorCodeValidationFailed, "Enable secure email change to generate link for current email") + } + params.NewEmail, terr = a.validateEmail(params.NewEmail) + if terr != nil { + return terr + } + if duplicateUser, terr := models.IsDuplicatedEmail(tx, params.NewEmail, user.Aud, user); terr != nil { + return internalServerError("Database error checking email").WithInternalError(terr) + } else if duplicateUser != nil { + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) + } + now := time.Now() + user.EmailChangeSentAt = &now + user.EmailChange = params.NewEmail + user.EmailChangeConfirmStatus = zeroConfirmation + if params.Type == "email_change_current" { + user.EmailChangeTokenCurrent = hashedToken + } else if params.Type == "email_change_new" { + user.EmailChangeTokenNew = crypto.GenerateTokenHash(params.NewEmail, otp) + } + terr = tx.UpdateOnly(user, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for email change") + return terr + } + if user.EmailChangeTokenCurrent != "" { + terr = models.CreateOneTimeToken(tx, user.ID, user.GetEmail(), user.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent) + if terr != nil { + terr = errors.Wrap(terr, "Database error creating email change token current in admin") + return terr + } + } + if user.EmailChangeTokenNew != "" { + terr = models.CreateOneTimeToken(tx, user.ID, user.EmailChange, user.EmailChangeTokenNew, models.EmailChangeTokenNew) + if terr != nil { + terr = errors.Wrap(terr, "Database error creating email change token new in admin") + return terr + } + } + default: + return badRequestError(ErrorCodeValidationFailed, "Invalid email action link type requested: %v", params.Type) + } + + if terr != nil { + return terr + } + + externalURL := getExternalHost(ctx) + url, terr = mailer.GetEmailActionLink(user, params.Type, referrer, externalURL) + if terr != nil { + return terr + } + return nil + }) + + if err != nil { + return err + } + + resp := GenerateLinkResponse{ + User: *user, + ActionLink: url, + EmailOtp: otp, + HashedToken: hashedToken, + VerificationType: params.Type, + RedirectTo: referrer, + } + + return sendJSON(w, http.StatusOK, resp) +} + +func (a *API) sendConfirmation(r *http.Request, tx *storage.Connection, u *models.User, flowType models.FlowType) error { + var err error + + config := a.config + maxFrequency := config.SMTP.MaxFrequency + otpLength := config.Mailer.OtpLength + + if err = validateSentWithinFrequencyLimit(u.ConfirmationSentAt, maxFrequency); err != nil { + return err + } + oldToken := u.ConfirmationToken + otp := crypto.GenerateOtp(otpLength) + + token := crypto.GenerateTokenHash(u.GetEmail(), otp) + u.ConfirmationToken = addFlowPrefixToToken(token, flowType) + now := time.Now() + if err = a.sendEmail(r, tx, u, mail.SignupVerification, otp, "", u.ConfirmationToken); err != nil { + u.ConfirmationToken = oldToken + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } else if herr, ok := err.(*HTTPError); ok { + return herr + } + return 
internalServerError("Error sending confirmation email").WithInternalError(err) + } + u.ConfirmationSentAt = &now + if err := tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at"); err != nil { + return internalServerError("Error sending confirmation email").WithInternalError(errors.Wrap(err, "Database error updating user for confirmation")) + } + + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken); err != nil { + return internalServerError("Error sending confirmation email").WithInternalError(errors.Wrap(err, "Database error creating confirmation token")) + } + + return nil +} + +func (a *API) sendInvite(r *http.Request, tx *storage.Connection, u *models.User) error { + config := a.config + otpLength := config.Mailer.OtpLength + var err error + oldToken := u.ConfirmationToken + otp := crypto.GenerateOtp(otpLength) + + u.ConfirmationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) + now := time.Now() + if err = a.sendEmail(r, tx, u, mail.InviteVerification, otp, "", u.ConfirmationToken); err != nil { + u.ConfirmationToken = oldToken + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } else if herr, ok := err.(*HTTPError); ok { + return herr + } + return internalServerError("Error sending invite email").WithInternalError(err) + } + u.InvitedAt = &now + u.ConfirmationSentAt = &now + err = tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at") + if err != nil { + return internalServerError("Error inviting user").WithInternalError(errors.Wrap(err, "Database error updating user for invite")) + } + + err = models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken) + if err != nil { + return internalServerError("Error inviting user").WithInternalError(errors.Wrap(err, "Database error creating confirmation token for invite")) + } + + return nil +} + +func (a *API) sendPasswordRecovery(r *http.Request, tx *storage.Connection, u *models.User, flowType models.FlowType) error { + config := a.config + otpLength := config.Mailer.OtpLength + + if err := validateSentWithinFrequencyLimit(u.RecoverySentAt, config.SMTP.MaxFrequency); err != nil { + return err + } + + oldToken := u.RecoveryToken + otp := crypto.GenerateOtp(otpLength) + + token := crypto.GenerateTokenHash(u.GetEmail(), otp) + u.RecoveryToken = addFlowPrefixToToken(token, flowType) + now := time.Now() + if err := a.sendEmail(r, tx, u, mail.RecoveryVerification, otp, "", u.RecoveryToken); err != nil { + u.RecoveryToken = oldToken + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } else if herr, ok := err.(*HTTPError); ok { + return herr + } + return internalServerError("Error sending recovery email").WithInternalError(err) + } + u.RecoverySentAt = &now + + if err := tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"); err != nil { + return internalServerError("Error sending recovery email").WithInternalError(errors.Wrap(err, "Database error updating user for recovery")) + } + + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken); err != nil { + return internalServerError("Error sending recovery email").WithInternalError(errors.Wrap(err, "Database error creating recovery token")) + } + + return nil +} + +func (a *API) sendReauthenticationOtp(r *http.Request, tx *storage.Connection, u 
*models.User) error { + config := a.config + maxFrequency := config.SMTP.MaxFrequency + otpLength := config.Mailer.OtpLength + + if err := validateSentWithinFrequencyLimit(u.ReauthenticationSentAt, maxFrequency); err != nil { + return err + } + + oldToken := u.ReauthenticationToken + otp := crypto.GenerateOtp(otpLength) + + u.ReauthenticationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) + now := time.Now() + + if err := a.sendEmail(r, tx, u, mail.ReauthenticationVerification, otp, "", u.ReauthenticationToken); err != nil { + u.ReauthenticationToken = oldToken + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } else if herr, ok := err.(*HTTPError); ok { + return herr + } + return internalServerError("Error sending reauthentication email").WithInternalError(err) + } + u.ReauthenticationSentAt = &now + if err := tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at"); err != nil { + return internalServerError("Error sending reauthentication email").WithInternalError(errors.Wrap(err, "Database error updating user for reauthentication")) + } + + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ReauthenticationToken, models.ReauthenticationToken); err != nil { + return internalServerError("Error sending reauthentication email").WithInternalError(errors.Wrap(err, "Database error creating reauthentication token")) + } + + return nil +} + +func (a *API) sendMagicLink(r *http.Request, tx *storage.Connection, u *models.User, flowType models.FlowType) error { + var err error + config := a.config + otpLength := config.Mailer.OtpLength + + // since Magic Link is just a recovery with a different template and behaviour + // around new users we will reuse the recovery db timer to prevent potential abuse + if err := validateSentWithinFrequencyLimit(u.RecoverySentAt, config.SMTP.MaxFrequency); err != nil { + return err + } + + oldToken := u.RecoveryToken + otp := crypto.GenerateOtp(otpLength) + + token := crypto.GenerateTokenHash(u.GetEmail(), otp) + u.RecoveryToken = addFlowPrefixToToken(token, flowType) + + now := time.Now() + if err = a.sendEmail(r, tx, u, mail.MagicLinkVerification, otp, "", u.RecoveryToken); err != nil { + u.RecoveryToken = oldToken + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } else if herr, ok := err.(*HTTPError); ok { + return herr + } + return internalServerError("Error sending magic link email").WithInternalError(err) + } + u.RecoverySentAt = &now + if err := tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"); err != nil { + return internalServerError("Error sending magic link email").WithInternalError(errors.Wrap(err, "Database error updating user for recovery")) + } + + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken); err != nil { + return internalServerError("Error sending magic link email").WithInternalError(errors.Wrap(err, "Database error creating recovery token")) + } + + return nil +} + +// sendEmailChange sends out an email change token to the new email. 
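+// When config.Mailer.SecureEmailChangeEnabled is set and the user already has an
+// email address, a token for the current address is generated in addition to the
+// token for the new address. Illustrative sketch of the resulting state (example
+// addresses, flow prefix added by addFlowPrefixToToken omitted):
+//
+//	u.EmailChangeTokenNew     = crypto.GenerateTokenHash("new@example.com", otpNew)
+//	u.EmailChangeTokenCurrent = crypto.GenerateTokenHash("old@example.com", otpCurrent)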
+func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models.User, email string, flowType models.FlowType) error { + config := a.config + otpLength := config.Mailer.OtpLength + + if err := validateSentWithinFrequencyLimit(u.EmailChangeSentAt, config.SMTP.MaxFrequency); err != nil { + return err + } + + otpNew := crypto.GenerateOtp(otpLength) + + u.EmailChange = email + token := crypto.GenerateTokenHash(u.EmailChange, otpNew) + u.EmailChangeTokenNew = addFlowPrefixToToken(token, flowType) + + otpCurrent := "" + if config.Mailer.SecureEmailChangeEnabled && u.GetEmail() != "" { + otpCurrent = crypto.GenerateOtp(otpLength) + + currentToken := crypto.GenerateTokenHash(u.GetEmail(), otpCurrent) + u.EmailChangeTokenCurrent = addFlowPrefixToToken(currentToken, flowType) + } + + u.EmailChangeConfirmStatus = zeroConfirmation + now := time.Now() + + if err := a.sendEmail(r, tx, u, mail.EmailChangeVerification, otpCurrent, otpNew, u.EmailChangeTokenNew); err != nil { + if errors.Is(err, EmailRateLimitExceeded) { + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, EmailRateLimitExceeded.Error()) + } else if herr, ok := err.(*HTTPError); ok { + return herr + } + return internalServerError("Error sending email change email").WithInternalError(err) + } + + u.EmailChangeSentAt = &now + if err := tx.UpdateOnly( + u, + "email_change_token_current", + "email_change_token_new", + "email_change", + "email_change_sent_at", + "email_change_confirm_status", + ); err != nil { + return internalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error updating user for email change")) + } + + if u.EmailChangeTokenCurrent != "" { + if err := models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent); err != nil { + return internalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error creating email change token current")) + } + } + + if u.EmailChangeTokenNew != "" { + if err := models.CreateOneTimeToken(tx, u.ID, u.EmailChange, u.EmailChangeTokenNew, models.EmailChangeTokenNew); err != nil { + return internalServerError("Error sending email change email").WithInternalError(errors.Wrap(err, "Database error creating email change token new")) + } + } + + return nil +} + +func (a *API) validateEmail(email string) (string, error) { + if email == "" { + return "", badRequestError(ErrorCodeValidationFailed, "An email address is required") + } + if len(email) > 255 { + return "", badRequestError(ErrorCodeValidationFailed, "An email address is too long") + } + if err := checkmail.ValidateFormat(email); err != nil { + return "", badRequestError(ErrorCodeValidationFailed, "Unable to validate email address: "+err.Error()) + } + + return strings.ToLower(email), nil +} + +func validateSentWithinFrequencyLimit(sentAt *time.Time, frequency time.Duration) error { + if sentAt != nil && sentAt.Add(frequency).After(time.Now()) { + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, generateFrequencyLimitErrorMessage(sentAt, frequency)) + } + return nil +} + +var emailLabelPattern = regexp.MustCompile("[+][^@]+@") + +func (a *API) checkEmailAddressAuthorization(email string) bool { + if len(a.config.External.Email.AuthorizedAddresses) > 0 { + // allow labelled emails when authorization rules are in place + normalized := emailLabelPattern.ReplaceAllString(email, "@") + + for _, authorizedAddress := range a.config.External.Email.AuthorizedAddresses { + if 
strings.EqualFold(normalized, authorizedAddress) { + return true + } + } + + return false + } + + return true +} + +func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User, emailActionType, otp, otpNew, tokenHashWithPrefix string) error { + ctx := r.Context() + config := a.config + referrerURL := utilities.GetReferrer(r, config) + externalURL := getExternalHost(ctx) + + if emailActionType != mail.EmailChangeVerification { + if u.GetEmail() != "" && !a.checkEmailAddressAuthorization(u.GetEmail()) { + return badRequestError(ErrorCodeEmailAddressNotAuthorized, "Email address %q cannot be used as it is not authorized", u.GetEmail()) + } + } else { + // first check that the user can update their address to the + // new one in u.EmailChange + if u.EmailChange != "" && !a.checkEmailAddressAuthorization(u.EmailChange) { + return badRequestError(ErrorCodeEmailAddressNotAuthorized, "Email address %q cannot be used as it is not authorized", u.EmailChange) + } + + // if secure email change is enabled, check that the user + // account (which could have been created before the authorized + // address authorization restriction was enabled) can even + // receive the confirmation message to the existing address + if config.Mailer.SecureEmailChangeEnabled && u.GetEmail() != "" && !a.checkEmailAddressAuthorization(u.GetEmail()) { + return badRequestError(ErrorCodeEmailAddressNotAuthorized, "Email address %q cannot be used as it is not authorized", u.GetEmail()) + } + } + + // if the number of events is set to zero, we immediately apply rate limits. + if config.RateLimitEmailSent.Events == 0 { + emailRateLimitCounter.Add( + ctx, + 1, + metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))), + ) + return EmailRateLimitExceeded + } + + // TODO(km): Deprecate this behaviour - rate limits should still be applied to autoconfirm + if !config.Mailer.Autoconfirm { + // apply rate limiting before the email is sent out + if ok := a.limiterOpts.Email.Allow(); !ok { + emailRateLimitCounter.Add( + ctx, + 1, + metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))), + ) + return EmailRateLimitExceeded + } + } + + if config.Hook.SendEmail.Enabled { + // When secure email change is disabled, we place the token for the new email on emailData.Token + if emailActionType == mail.EmailChangeVerification && !config.Mailer.SecureEmailChangeEnabled && u.GetEmail() != "" { + otp = otpNew + } + + emailData := mail.EmailData{ + Token: otp, + EmailActionType: emailActionType, + RedirectTo: referrerURL, + SiteURL: externalURL.String(), + TokenHash: tokenHashWithPrefix, + } + if emailActionType == mail.EmailChangeVerification && config.Mailer.SecureEmailChangeEnabled && u.GetEmail() != "" { + emailData.TokenNew = otpNew + emailData.TokenHashNew = u.EmailChangeTokenCurrent + } + input := hooks.SendEmailInput{ + User: u, + EmailData: emailData, + } + output := hooks.SendEmailOutput{} + return a.invokeHook(tx, r, &input, &output) + } + + mr := a.Mailer() + var err error + switch emailActionType { + case mail.SignupVerification: + err = mr.ConfirmationMail(r, u, otp, referrerURL, externalURL) + case mail.MagicLinkVerification: + err = mr.MagicLinkMail(r, u, otp, referrerURL, externalURL) + case mail.ReauthenticationVerification: + err = mr.ReauthenticateMail(r, u, otp) + case mail.RecoveryVerification: + err = mr.RecoveryMail(r, u, otp, referrerURL, externalURL) + case mail.InviteVerification: + err = mr.InviteMail(r, u, otp, referrerURL, externalURL) + case 
mail.EmailChangeVerification: + err = mr.EmailChangeMail(r, u, otpNew, otp, referrerURL, externalURL) + default: + err = errors.New("invalid email action type") + } + + switch { + case errors.Is(err, mail.ErrInvalidEmailAddress), + errors.Is(err, mail.ErrInvalidEmailFormat), + errors.Is(err, mail.ErrInvalidEmailDNS): + return badRequestError( + ErrorCodeEmailAddressInvalid, + "Email address %q is invalid", + u.GetEmail()) + default: + return err + } +} diff --git a/auth_v2.169.0/internal/api/mail_test.go b/auth_v2.169.0/internal/api/mail_test.go new file mode 100644 index 0000000..87fa946 --- /dev/null +++ b/auth_v2.169.0/internal/api/mail_test.go @@ -0,0 +1,256 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gobwas/glob" + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" +) + +type MailTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestMail(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &MailTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *MailTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + ts.Config.Mailer.SecureEmailChangeEnabled = true + + // Create User + u, err := models.NewUser("12345678", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating new user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new user") +} + +func (ts *MailTestSuite) TestValidateEmail() { + cases := []struct { + desc string + email string + expectedEmail string + expectedError error + }{ + { + desc: "valid email", + email: "test@example.com", + expectedEmail: "test@example.com", + expectedError: nil, + }, + { + desc: "email should be normalized", + email: "TEST@EXAMPLE.COM", + expectedEmail: "test@example.com", + expectedError: nil, + }, + { + desc: "empty email should return error", + email: "", + expectedEmail: "", + expectedError: badRequestError(ErrorCodeValidationFailed, "An email address is required"), + }, + { + desc: "email length exceeds 255 characters", + // email has 256 characters + email: "testtesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttesttest@example.com", + expectedEmail: "", + expectedError: badRequestError(ErrorCodeValidationFailed, "An email address is too long"), + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + email, err := ts.API.validateEmail(c.email) + require.Equal(ts.T(), c.expectedError, err) + require.Equal(ts.T(), c.expectedEmail, email) + }) + } +} + +func (ts *MailTestSuite) TestGenerateLink() { + // create admin jwt + claims := &AccessTokenClaims{ + Role: "supabase_admin", + } + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err, "Error generating admin jwt") + + ts.setURIAllowListMap("http://localhost:8000/**") + // create test cases + cases := []struct { + Desc string + Body GenerateLinkParams + ExpectedCode int + ExpectedResponse map[string]interface{} + }{ + { + Desc: "Generate signup 
link for new user", + Body: GenerateLinkParams{ + Email: "new_user@example.com", + Password: "secret123", + Type: "signup", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": ts.Config.SiteURL, + }, + }, + { + Desc: "Generate signup link for existing user", + Body: GenerateLinkParams{ + Email: "test@example.com", + Password: "secret123", + Type: "signup", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": ts.Config.SiteURL, + }, + }, + { + Desc: "Generate signup link with custom redirect url", + Body: GenerateLinkParams{ + Email: "test@example.com", + Password: "secret123", + Type: "signup", + RedirectTo: "http://localhost:8000/welcome", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": "http://localhost:8000/welcome", + }, + }, + { + Desc: "Generate magic link", + Body: GenerateLinkParams{ + Email: "test@example.com", + Type: "magiclink", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": ts.Config.SiteURL, + }, + }, + { + Desc: "Generate invite link", + Body: GenerateLinkParams{ + Email: "test@example.com", + Type: "invite", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": ts.Config.SiteURL, + }, + }, + { + Desc: "Generate recovery link", + Body: GenerateLinkParams{ + Email: "test@example.com", + Type: "recovery", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": ts.Config.SiteURL, + }, + }, + { + Desc: "Generate email change link", + Body: GenerateLinkParams{ + Email: "test@example.com", + NewEmail: "new@example.com", + Type: "email_change_current", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": ts.Config.SiteURL, + }, + }, + { + Desc: "Generate email change link", + Body: GenerateLinkParams{ + Email: "test@example.com", + NewEmail: "new@example.com", + Type: "email_change_new", + }, + ExpectedCode: http.StatusOK, + ExpectedResponse: map[string]interface{}{ + "redirect_to": ts.Config.SiteURL, + }, + }, + } + + customDomainUrl, err := url.ParseRequestURI("https://example.gotrue.com") + require.NoError(ts.T(), err) + + originalHosts := ts.API.config.Mailer.ExternalHosts + ts.API.config.Mailer.ExternalHosts = []string{ + "example.gotrue.com", + } + + for _, c := range cases { + ts.Run(c.Desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.Body)) + req := httptest.NewRequest(http.MethodPost, customDomainUrl.String()+"/admin/generate_link", &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), c.ExpectedCode, w.Code) + + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.Contains(ts.T(), data, "action_link") + require.Contains(ts.T(), data, "email_otp") + require.Contains(ts.T(), data, "hashed_token") + require.Contains(ts.T(), data, "redirect_to") + require.Equal(ts.T(), c.Body.Type, data["verification_type"]) + + // check if redirect_to is correct + require.Equal(ts.T(), c.ExpectedResponse["redirect_to"], data["redirect_to"]) + + // check if hashed_token matches hash function of email and the raw otp + require.Equal(ts.T(), crypto.GenerateTokenHash(c.Body.Email, data["email_otp"].(string)), data["hashed_token"]) + + // check if the host used in 
the email link matches the initial request host + u, err := url.ParseRequestURI(data["action_link"].(string)) + require.NoError(ts.T(), err) + require.Equal(ts.T(), req.Host, u.Host) + }) + } + + ts.API.config.Mailer.ExternalHosts = originalHosts +} + +func (ts *MailTestSuite) setURIAllowListMap(uris ...string) { + for _, uri := range uris { + g := glob.MustCompile(uri, '.', '/') + ts.Config.URIAllowListMap[uri] = g + } +} diff --git a/auth_v2.169.0/internal/api/mfa.go b/auth_v2.169.0/internal/api/mfa.go new file mode 100644 index 0000000..4ac2b9b --- /dev/null +++ b/auth_v2.169.0/internal/api/mfa.go @@ -0,0 +1,1029 @@ +package api + +import ( + "bytes" + "crypto/subtle" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/aaronarduino/goqrsvg" + svg "github.com/ajstarks/svgo" + "github.com/boombuler/barcode/qr" + wbnprotocol "github.com/go-webauthn/webauthn/protocol" + "github.com/go-webauthn/webauthn/webauthn" + "github.com/gofrs/uuid" + "github.com/pquerna/otp" + "github.com/pquerna/otp/totp" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/hooks" + "github.com/supabase/auth/internal/metering" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +const DefaultQRSize = 3 + +type EnrollFactorParams struct { + FriendlyName string `json:"friendly_name"` + FactorType string `json:"factor_type"` + Issuer string `json:"issuer"` + Phone string `json:"phone"` +} + +type TOTPObject struct { + QRCode string `json:"qr_code,omitempty"` + Secret string `json:"secret,omitempty"` + URI string `json:"uri,omitempty"` +} + +type EnrollFactorResponse struct { + ID uuid.UUID `json:"id"` + Type string `json:"type"` + FriendlyName string `json:"friendly_name"` + TOTP *TOTPObject `json:"totp,omitempty"` + Phone string `json:"phone,omitempty"` +} + +type ChallengeFactorParams struct { + Channel string `json:"channel"` + WebAuthn *WebAuthnParams `json:"web_authn,omitempty"` +} + +type VerifyFactorParams struct { + ChallengeID uuid.UUID `json:"challenge_id"` + Code string `json:"code"` + WebAuthn *WebAuthnParams `json:"web_authn,omitempty"` +} + +type ChallengeFactorResponse struct { + ID uuid.UUID `json:"id"` + Type string `json:"type"` + ExpiresAt int64 `json:"expires_at,omitempty"` + CredentialRequestOptions *wbnprotocol.CredentialAssertion `json:"credential_request_options,omitempty"` + CredentialCreationOptions *wbnprotocol.CredentialCreation `json:"credential_creation_options,omitempty"` +} + +type UnenrollFactorResponse struct { + ID uuid.UUID `json:"id"` +} + +type WebAuthnParams struct { + RPID string `json:"rp_id,omitempty"` + // Can encode multiple origins as comma separated values like: "origin1,origin2" + RPOrigins string `json:"rp_origins,omitempty"` + AssertionResponse json.RawMessage `json:"assertion_response,omitempty"` + CreationResponse json.RawMessage `json:"creation_response,omitempty"` +} + +func (w *WebAuthnParams) GetRPOrigins() []string { + if w.RPOrigins == "" { + return nil + } + return strings.Split(w.RPOrigins, ",") +} + +func (w *WebAuthnParams) ToConfig() (*webauthn.WebAuthn, error) { + if w.RPID == "" { + return nil, fmt.Errorf("webAuthn RP ID cannot be empty") + } + + origins := w.GetRPOrigins() + if len(origins) == 0 { + return nil, fmt.Errorf("webAuthn RP Origins cannot be empty") + } + + var validOrigins []string + var 
invalidOrigins []string + + for _, origin := range origins { + parsedURL, err := url.Parse(origin) + if err != nil || (parsedURL.Scheme != "https" && !(parsedURL.Scheme == "http" && parsedURL.Hostname() == "localhost")) || parsedURL.Host == "" { + invalidOrigins = append(invalidOrigins, origin) + } else { + validOrigins = append(validOrigins, origin) + } + } + + if len(invalidOrigins) > 0 { + return nil, fmt.Errorf("invalid RP origins: %s", strings.Join(invalidOrigins, ", ")) + } + + wconfig := &webauthn.Config{ + // DisplayName is optional in spec but required to be non-empty in libary, we use the RPID as a placeholder. + RPDisplayName: w.RPID, + RPID: w.RPID, + RPOrigins: validOrigins, + } + + return webauthn.New(wconfig) +} + +const ( + QRCodeGenerationErrorMessage = "Error generating QR Code" +) + +func validateFactors(db *storage.Connection, user *models.User, newFactorName string, config *conf.GlobalConfiguration, session *models.Session) error { + if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil { + return err + } + if err := db.Load(user, "Factors"); err != nil { + return err + } + factorCount := len(user.Factors) + numVerifiedFactors := 0 + + for _, factor := range user.Factors { + if factor.FriendlyName == newFactorName { + return unprocessableEntityError( + ErrorCodeMFAFactorNameConflict, + fmt.Sprintf("A factor with the friendly name %q for this user already exists", newFactorName), + ) + } + if factor.IsVerified() { + numVerifiedFactors++ + } + } + + if factorCount >= int(config.MFA.MaxEnrolledFactors) { + return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") + } + + if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { + return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") + } + + if numVerifiedFactors > 0 && session != nil && !session.IsAAL2() { + return forbiddenError(ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor") + } + + return nil +} + +func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params *EnrollFactorParams) error { + ctx := r.Context() + user := getUser(ctx) + session := getSession(ctx) + db := a.db.WithContext(ctx) + if params.Phone == "" { + return badRequestError(ErrorCodeValidationFailed, "Phone number required to enroll Phone factor") + } + + phone, err := validatePhone(params.Phone) + if err != nil { + return badRequestError(ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") + } + + var factorsToDelete []models.Factor + for _, factor := range user.Factors { + if factor.IsPhoneFactor() && factor.Phone.String() == phone { + if factor.IsVerified() { + return unprocessableEntityError( + ErrorCodeMFAVerifiedFactorExists, + "A verified phone factor already exists, unenroll the existing factor to continue", + ) + } else if factor.IsUnverified() { + factorsToDelete = append(factorsToDelete, factor) + } + } + } + + if err := db.Destroy(&factorsToDelete); err != nil { + return internalServerError("Database error deleting unverified phone factors").WithInternalError(err) + } + + if err := validateFactors(db, user, params.FriendlyName, a.config, session); err != nil { + return err + } + + factor := models.NewPhoneFactor(user, phone, params.FriendlyName) + err = db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Create(factor); terr != nil { + return terr + } + if terr := 
models.NewAuditLogEntry(r, tx, user, models.EnrollFactorAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "factor_type": factor.FactorType, + }); terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + return sendJSON(w, http.StatusOK, &EnrollFactorResponse{ + ID: factor.ID, + Type: models.Phone, + FriendlyName: factor.FriendlyName, + Phone: params.Phone, + }) +} + +func (a *API) enrollWebAuthnFactor(w http.ResponseWriter, r *http.Request, params *EnrollFactorParams) error { + ctx := r.Context() + user := getUser(ctx) + session := getSession(ctx) + db := a.db.WithContext(ctx) + + if err := validateFactors(db, user, params.FriendlyName, a.config, session); err != nil { + return err + } + + factor := models.NewWebAuthnFactor(user, params.FriendlyName) + err := db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Create(factor); terr != nil { + return terr + } + if terr := models.NewAuditLogEntry(r, tx, user, models.EnrollFactorAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "factor_type": factor.FactorType, + }); terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + return sendJSON(w, http.StatusOK, &EnrollFactorResponse{ + ID: factor.ID, + Type: models.WebAuthn, + FriendlyName: factor.FriendlyName, + }) +} + +func (a *API) enrollTOTPFactor(w http.ResponseWriter, r *http.Request, params *EnrollFactorParams) error { + ctx := r.Context() + user := getUser(ctx) + db := a.db.WithContext(ctx) + config := a.config + session := getSession(ctx) + issuer := "" + if params.Issuer == "" { + u, err := url.ParseRequestURI(config.SiteURL) + if err != nil { + return internalServerError("site url is improperly formatted") + } + issuer = u.Host + } else { + issuer = params.Issuer + } + + if err := validateFactors(db, user, params.FriendlyName, config, session); err != nil { + return err + } + + var factor *models.Factor + var buf bytes.Buffer + var key *otp.Key + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: issuer, + AccountName: user.GetEmail(), + }) + if err != nil { + return internalServerError(QRCodeGenerationErrorMessage).WithInternalError(err) + } + + svgData := svg.New(&buf) + qrCode, _ := qr.Encode(key.String(), qr.H, qr.Auto) + qs := goqrsvg.NewQrSVG(qrCode, DefaultQRSize) + qs.StartQrSVG(svgData) + if err = qs.WriteQrSVG(svgData); err != nil { + return internalServerError(QRCodeGenerationErrorMessage).WithInternalError(err) + } + svgData.End() + + factor = models.NewTOTPFactor(user, params.FriendlyName) + if err := factor.SetSecret(key.Secret(), config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + return err + } + + err = db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Create(factor); terr != nil { + return terr + } + + if terr := models.NewAuditLogEntry(r, tx, user, models.EnrollFactorAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + }); terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + return sendJSON(w, http.StatusOK, &EnrollFactorResponse{ + ID: factor.ID, + Type: models.TOTP, + FriendlyName: factor.FriendlyName, + TOTP: &TOTPObject{ + // See: https://css-tricks.com/probably-dont-base64-svg/ + QRCode: buf.String(), + Secret: key.Secret(), + URI: key.URL(), + }, + }) +} + +func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + user := getUser(ctx) 
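+ // Illustrative enroll payloads dispatched below by factor_type (example values only,
+ // matching the EnrollFactorParams JSON tags):
+ //	{"factor_type": "totp", "friendly_name": "Authenticator app", "issuer": "example.com"}
+ //	{"factor_type": "phone", "friendly_name": "Work phone", "phone": "+15555550100"}
+ //	{"factor_type": "webauthn", "friendly_name": "Security key"}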
+ session := getSession(ctx) + config := a.config + + if session == nil || user == nil { + return internalServerError("A valid session and a registered user are required to enroll a factor") + } + params := &EnrollFactorParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + switch params.FactorType { + case models.Phone: + if !config.MFA.Phone.EnrollEnabled { + return unprocessableEntityError(ErrorCodeMFAPhoneEnrollDisabled, "MFA enroll is disabled for Phone") + } + return a.enrollPhoneFactor(w, r, params) + case models.TOTP: + if !config.MFA.TOTP.EnrollEnabled { + return unprocessableEntityError(ErrorCodeMFATOTPEnrollDisabled, "MFA enroll is disabled for TOTP") + } + return a.enrollTOTPFactor(w, r, params) + case models.WebAuthn: + if !config.MFA.WebAuthn.EnrollEnabled { + return unprocessableEntityError(ErrorCodeMFAWebAuthnEnrollDisabled, "MFA enroll is disabled for WebAuthn") + } + return a.enrollWebAuthnFactor(w, r, params) + default: + return badRequestError(ErrorCodeValidationFailed, "factor_type needs to be totp, phone, or webauthn") + } + +} + +func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + config := a.config + db := a.db.WithContext(ctx) + user := getUser(ctx) + factor := getFactor(ctx) + ipAddress := utilities.GetIPAddress(r) + params := &ChallengeFactorParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + channel := params.Channel + if channel == "" { + channel = sms_provider.SMSProvider + } + if !sms_provider.IsValidMessageChannel(channel, config) { + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) + } + + if factor.IsPhoneFactor() && factor.LastChallengedAt != nil { + if !factor.LastChallengedAt.Add(config.MFA.Phone.MaxFrequency).Before(time.Now()) { + return tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, generateFrequencyLimitErrorMessage(factor.LastChallengedAt, config.MFA.Phone.MaxFrequency)) + } + } + + otp := crypto.GenerateOtp(config.MFA.Phone.OtpLength) + + challenge, err := factor.CreatePhoneChallenge(ipAddress, otp, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey) + if err != nil { + return internalServerError("error creating SMS Challenge") + } + + message, err := generateSMSFromTemplate(config.MFA.Phone.SMSTemplate, otp) + if err != nil { + return internalServerError("error generating sms template").WithInternalError(err) + } + + if config.Hook.SendSMS.Enabled { + input := hooks.SendSMSInput{ + User: user, + SMS: hooks.SMS{ + OTP: otp, + SMSType: "mfa", + }, + } + output := hooks.SendSMSOutput{} + err := a.invokeHook(a.db, r, &input, &output) + if err != nil { + return internalServerError("error invoking hook") + } + } else { + smsProvider, err := sms_provider.GetSmsProvider(*config) + if err != nil { + return internalServerError("Failed to get SMS provider").WithInternalError(err) + } + // We omit messageID for now, can consider reinstating if there are requests. 
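+ // message was rendered from config.MFA.Phone.SMSTemplate with the generated otp;
+ // assuming a template of the form "Your code is {{ .Code }}", an OTP of 123456
+ // would produce "Your code is 123456".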
+ if _, err = smsProvider.SendMessage(factor.Phone.String(), message, channel, otp); err != nil { + return internalServerError("error sending message").WithInternalError(err) + } + } + if err := db.Transaction(func(tx *storage.Connection) error { + if terr := factor.WriteChallengeToDatabase(tx, challenge); terr != nil { + return terr + } + + if terr := models.NewAuditLogEntry(r, tx, user, models.CreateChallengeAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "factor_status": factor.Status, + }); terr != nil { + return terr + } + return nil + }); err != nil { + return err + } + return sendJSON(w, http.StatusOK, &ChallengeFactorResponse{ + ID: challenge.ID, + Type: factor.FactorType, + ExpiresAt: challenge.GetExpiryTime(config.MFA.ChallengeExpiryDuration).Unix(), + }) +} + +func (a *API) challengeTOTPFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + config := a.config + db := a.db.WithContext(ctx) + + user := getUser(ctx) + factor := getFactor(ctx) + ipAddress := utilities.GetIPAddress(r) + + challenge := factor.CreateChallenge(ipAddress) + + if err := db.Transaction(func(tx *storage.Connection) error { + if terr := factor.WriteChallengeToDatabase(tx, challenge); terr != nil { + return terr + } + if terr := models.NewAuditLogEntry(r, tx, user, models.CreateChallengeAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "factor_status": factor.Status, + }); terr != nil { + return terr + } + return nil + }); err != nil { + return err + } + + return sendJSON(w, http.StatusOK, &ChallengeFactorResponse{ + ID: challenge.ID, + Type: factor.FactorType, + ExpiresAt: challenge.GetExpiryTime(config.MFA.ChallengeExpiryDuration).Unix(), + }) +} + +func (a *API) challengeWebAuthnFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + + user := getUser(ctx) + factor := getFactor(ctx) + ipAddress := utilities.GetIPAddress(r) + + params := &ChallengeFactorParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + if params.WebAuthn == nil { + return badRequestError(ErrorCodeValidationFailed, "web_authn config required") + } + webAuthn, err := params.WebAuthn.ToConfig() + if err != nil { + return err + } + var response *ChallengeFactorResponse + var ws *models.WebAuthnSessionData + var challenge *models.Challenge + if factor.IsUnverified() { + options, session, err := webAuthn.BeginRegistration(user) + if err != nil { + return internalServerError("Failed to generate WebAuthn registration data").WithInternalError(err) + } + ws = &models.WebAuthnSessionData{ + SessionData: session, + } + challenge = ws.ToChallenge(factor.ID, ipAddress) + + response = &ChallengeFactorResponse{ + CredentialCreationOptions: options, + Type: factor.FactorType, + ID: challenge.ID, + } + + } else if factor.IsVerified() { + options, session, err := webAuthn.BeginLogin(user) + if err != nil { + return err + } + ws = &models.WebAuthnSessionData{ + SessionData: session, + } + challenge = ws.ToChallenge(factor.ID, ipAddress) + response = &ChallengeFactorResponse{ + CredentialRequestOptions: options, + Type: factor.FactorType, + ID: challenge.ID, + } + + } + + if err := factor.WriteChallengeToDatabase(db, challenge); err != nil { + return err + } + response.ExpiresAt = challenge.GetExpiryTime(config.MFA.ChallengeExpiryDuration).Unix() + + return sendJSON(w, http.StatusOK, response) + +} + +func (a *API) validateChallenge(r *http.Request, db *storage.Connection, factor 
*models.Factor, challengeID uuid.UUID) (*models.Challenge, error) { + config := a.config + currentIP := utilities.GetIPAddress(r) + + challenge, err := factor.FindChallengeByID(db, challengeID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, unprocessableEntityError(ErrorCodeMFAFactorNotFound, "MFA factor with the provided challenge ID not found") + } + return nil, internalServerError("Database error finding Challenge").WithInternalError(err) + } + + if challenge.VerifiedAt != nil || challenge.IPAddress != currentIP { + return nil, unprocessableEntityError(ErrorCodeMFAIPAddressMismatch, "Challenge and verify IP addresses mismatch.") + } + + if challenge.HasExpired(config.MFA.ChallengeExpiryDuration) { + if err := db.Destroy(challenge); err != nil { + return nil, internalServerError("Database error deleting challenge").WithInternalError(err) + } + return nil, unprocessableEntityError(ErrorCodeMFAChallengeExpired, "MFA challenge %v has expired, verify against another challenge or create a new challenge.", challenge.ID) + } + + return challenge, nil +} + +func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + config := a.config + factor := getFactor(ctx) + + switch factor.FactorType { + case models.Phone: + if !config.MFA.Phone.VerifyEnabled { + return unprocessableEntityError(ErrorCodeMFAPhoneVerifyDisabled, "MFA verification is disabled for Phone") + } + return a.challengePhoneFactor(w, r) + + case models.TOTP: + if !config.MFA.TOTP.VerifyEnabled { + return unprocessableEntityError(ErrorCodeMFATOTPVerifyDisabled, "MFA verification is disabled for TOTP") + } + return a.challengeTOTPFactor(w, r) + case models.WebAuthn: + if !config.MFA.WebAuthn.VerifyEnabled { + return unprocessableEntityError(ErrorCodeMFAWebAuthnVerifyDisabled, "MFA verification is disabled for WebAuthn") + } + return a.challengeWebAuthnFactor(w, r) + default: + return badRequestError(ErrorCodeValidationFailed, "factor_type needs to be totp, phone, or webauthn") + } + +} + +func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *VerifyFactorParams) error { + var err error + ctx := r.Context() + user := getUser(ctx) + factor := getFactor(ctx) + config := a.config + db := a.db.WithContext(ctx) + + challenge, err := a.validateChallenge(r, db, factor, params.ChallengeID) + if err != nil { + return err + } + + secret, shouldReEncrypt, err := factor.GetSecret(config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID) + if err != nil { + return internalServerError("Database error verifying MFA TOTP secret").WithInternalError(err) + } + + valid, verr := totp.ValidateCustom(params.Code, secret, time.Now().UTC(), totp.ValidateOpts{ + Period: 30, + Skew: 1, + Digits: otp.DigitsSix, + Algorithm: otp.AlgorithmSHA1, + }) + + if config.Hook.MFAVerificationAttempt.Enabled { + input := hooks.MFAVerificationAttemptInput{ + UserID: user.ID, + FactorID: factor.ID, + Valid: valid, + } + + output := hooks.MFAVerificationAttemptOutput{} + err := a.invokeHook(nil, r, &input, &output) + if err != nil { + return err + } + + if output.Decision == hooks.HookRejection { + if err := models.Logout(db, user.ID); err != nil { + return err + } + + if output.Message == "" { + output.Message = hooks.DefaultMFAHookRejectionMessage + } + + return forbiddenError(ErrorCodeMFAVerificationRejected, output.Message) + } + } + if !valid { + if shouldReEncrypt && config.Security.DBEncryption.Encrypt { + if err := 
factor.SetSecret(secret, true, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + return err + } + + if err := db.UpdateOnly(factor, "secret"); err != nil { + return err + } + } + return unprocessableEntityError(ErrorCodeMFAVerificationFailed, "Invalid TOTP code entered").WithInternalError(verr) + } + + var token *AccessTokenResponse + + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if terr = models.NewAuditLogEntry(r, tx, user, models.VerifyFactorAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "challenge_id": challenge.ID, + "factor_type": factor.FactorType, + }); terr != nil { + return terr + } + if terr = challenge.Verify(tx); terr != nil { + return terr + } + if !factor.IsVerified() { + if terr = factor.UpdateStatus(tx, models.FactorStateVerified); terr != nil { + return terr + } + } + if shouldReEncrypt && config.Security.DBEncryption.Encrypt { + es, terr := crypto.NewEncryptedString(factor.ID.String(), []byte(secret), config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey) + if terr != nil { + return terr + } + + factor.Secret = es.String() + if terr := tx.UpdateOnly(factor, "secret"); terr != nil { + return terr + } + } + user, terr = models.FindUserByID(tx, user.ID) + if terr != nil { + return terr + } + + token, terr = a.updateMFASessionAndClaims(r, tx, user, models.TOTPSignIn, models.GrantParams{ + FactorID: &factor.ID, + }) + if terr != nil { + return terr + } + if terr = models.InvalidateSessionsWithAALLessThan(tx, user.ID, models.AAL2.String()); terr != nil { + return internalServerError("Failed to update sessions. %s", terr) + } + if terr = models.DeleteUnverifiedFactors(tx, user, factor.FactorType); terr != nil { + return internalServerError("Error removing unverified factors. 
%s", terr) + } + return nil + }) + if err != nil { + return err + } + metering.RecordLogin(string(models.MFACodeLoginAction), user.ID) + + return sendJSON(w, http.StatusOK, token) + +} + +func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params *VerifyFactorParams) error { + ctx := r.Context() + config := a.config + user := getUser(ctx) + factor := getFactor(ctx) + db := a.db.WithContext(ctx) + currentIP := utilities.GetIPAddress(r) + + challenge, err := a.validateChallenge(r, db, factor, params.ChallengeID) + if err != nil { + return err + } + + if challenge.VerifiedAt != nil || challenge.IPAddress != currentIP { + return unprocessableEntityError(ErrorCodeMFAIPAddressMismatch, "Challenge and verify IP addresses mismatch") + } + + if challenge.HasExpired(config.MFA.ChallengeExpiryDuration) { + if err := db.Destroy(challenge); err != nil { + return internalServerError("Database error deleting challenge").WithInternalError(err) + } + return unprocessableEntityError(ErrorCodeMFAChallengeExpired, "MFA challenge %v has expired, verify against another challenge or create a new challenge.", challenge.ID) + } + var valid bool + var otpCode string + var shouldReEncrypt bool + if config.Sms.IsTwilioVerifyProvider() { + smsProvider, err := sms_provider.GetSmsProvider(*config) + if err != nil { + return internalServerError("Failed to get SMS provider").WithInternalError(err) + } + if err := smsProvider.VerifyOTP(factor.Phone.String(), params.Code); err != nil { + return forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) + } + valid = true + } else { + otpCode, shouldReEncrypt, err = challenge.GetOtpCode(config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID) + if err != nil { + return internalServerError("Database error verifying MFA TOTP secret").WithInternalError(err) + } + valid = subtle.ConstantTimeCompare([]byte(otpCode), []byte(params.Code)) == 1 + } + if config.Hook.MFAVerificationAttempt.Enabled { + input := hooks.MFAVerificationAttemptInput{ + UserID: user.ID, + FactorID: factor.ID, + FactorType: factor.FactorType, + Valid: valid, + } + + output := hooks.MFAVerificationAttemptOutput{} + err := a.invokeHook(nil, r, &input, &output) + if err != nil { + return err + } + + if output.Decision == hooks.HookRejection { + if err := models.Logout(db, user.ID); err != nil { + return err + } + + if output.Message == "" { + output.Message = hooks.DefaultMFAHookRejectionMessage + } + + return forbiddenError(ErrorCodeMFAVerificationRejected, output.Message) + } + } + if !valid { + if shouldReEncrypt && config.Security.DBEncryption.Encrypt { + if err := challenge.SetOtpCode(otpCode, true, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + return err + } + + if err := db.UpdateOnly(challenge, "otp_code"); err != nil { + return err + } + } + return unprocessableEntityError(ErrorCodeMFAVerificationFailed, "Invalid MFA Phone code entered") + } + + var token *AccessTokenResponse + + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if terr = models.NewAuditLogEntry(r, tx, user, models.VerifyFactorAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "challenge_id": challenge.ID, + "factor_type": factor.FactorType, + }); terr != nil { + return terr + } + if terr = challenge.Verify(tx); terr != nil { + return terr + } + if !factor.IsVerified() { + if terr = 
factor.UpdateStatus(tx, models.FactorStateVerified); terr != nil { + return terr + } + } + user, terr = models.FindUserByID(tx, user.ID) + if terr != nil { + return terr + } + + token, terr = a.updateMFASessionAndClaims(r, tx, user, models.MFAPhone, models.GrantParams{ + FactorID: &factor.ID, + }) + if terr != nil { + return terr + } + if terr = models.InvalidateSessionsWithAALLessThan(tx, user.ID, models.AAL2.String()); terr != nil { + return internalServerError("Failed to update sessions. %s", terr) + } + if terr = models.DeleteUnverifiedFactors(tx, user, factor.FactorType); terr != nil { + return internalServerError("Error removing unverified factors. %s", terr) + } + return nil + }) + if err != nil { + return err + } + metering.RecordLogin(string(models.MFACodeLoginAction), user.ID) + + return sendJSON(w, http.StatusOK, token) +} + +func (a *API) verifyWebAuthnFactor(w http.ResponseWriter, r *http.Request, params *VerifyFactorParams) error { + ctx := r.Context() + user := getUser(ctx) + factor := getFactor(ctx) + db := a.db.WithContext(ctx) + + var webAuthn *webauthn.WebAuthn + var credential *webauthn.Credential + var err error + + switch { + case params.WebAuthn == nil: + return badRequestError(ErrorCodeValidationFailed, "WebAuthn config required") + case factor.IsVerified() && params.WebAuthn.AssertionResponse == nil: + return badRequestError(ErrorCodeValidationFailed, "creation_response required to login") + case factor.IsUnverified() && params.WebAuthn.CreationResponse == nil: + return badRequestError(ErrorCodeValidationFailed, "assertion_response required to login") + default: + webAuthn, err = params.WebAuthn.ToConfig() + if err != nil { + return err + } + } + + challenge, err := a.validateChallenge(r, db, factor, params.ChallengeID) + if err != nil { + return err + } + webAuthnSession := *challenge.WebAuthnSessionData.SessionData + // Once the challenge is validated, we consume the challenge + if err := db.Destroy(challenge); err != nil { + return internalServerError("Database error deleting challenge").WithInternalError(err) + } + + if factor.IsUnverified() { + parsedResponse, err := wbnprotocol.ParseCredentialCreationResponseBody(bytes.NewReader(params.WebAuthn.CreationResponse)) + if err != nil { + return badRequestError(ErrorCodeValidationFailed, "Invalid credential_creation_response") + } + credential, err = webAuthn.CreateCredential(user, webAuthnSession, parsedResponse) + if err != nil { + return err + } + + } else if factor.IsVerified() { + parsedResponse, err := wbnprotocol.ParseCredentialRequestResponseBody(bytes.NewReader(params.WebAuthn.AssertionResponse)) + if err != nil { + return badRequestError(ErrorCodeValidationFailed, "Invalid credential_request_response") + } + credential, err = webAuthn.ValidateLogin(user, webAuthnSession, parsedResponse) + if err != nil { + return internalServerError("Failed to validate WebAuthn MFA response").WithInternalError(err) + } + } + var token *AccessTokenResponse + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if terr = models.NewAuditLogEntry(r, tx, user, models.VerifyFactorAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "challenge_id": challenge.ID, + "factor_type": factor.FactorType, + }); terr != nil { + return terr + } + // Challenge verification not needed as the challenge is destroyed on use + if !factor.IsVerified() { + if terr = factor.UpdateStatus(tx, models.FactorStateVerified); terr != nil { + return terr + } + if terr = factor.SaveWebAuthnCredential(tx, 
credential); terr != nil { + return terr + } + } + user, terr = models.FindUserByID(tx, user.ID) + if terr != nil { + return terr + } + token, terr = a.updateMFASessionAndClaims(r, tx, user, models.MFAWebAuthn, models.GrantParams{ + FactorID: &factor.ID, + }) + if terr != nil { + return terr + } + if terr = models.InvalidateSessionsWithAALLessThan(tx, user.ID, models.AAL2.String()); terr != nil { + return internalServerError("Failed to update session").WithInternalError(terr) + } + if terr = models.DeleteUnverifiedFactors(tx, user, models.WebAuthn); terr != nil { + return internalServerError("Failed to remove unverified MFA WebAuthn factors").WithInternalError(terr) + } + return nil + }) + if err != nil { + return err + } + metering.RecordLogin(string(models.MFACodeLoginAction), user.ID) + + return sendJSON(w, http.StatusOK, token) +} + +func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + factor := getFactor(ctx) + config := a.config + + params := &VerifyFactorParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + if params.Code == "" && factor.FactorType != models.WebAuthn { + return badRequestError(ErrorCodeValidationFailed, "Code needs to be non-empty") + } + + switch factor.FactorType { + case models.Phone: + if !config.MFA.Phone.VerifyEnabled { + return unprocessableEntityError(ErrorCodeMFAPhoneVerifyDisabled, "MFA verification is disabled for Phone") + } + + return a.verifyPhoneFactor(w, r, params) + case models.TOTP: + if !config.MFA.TOTP.VerifyEnabled { + return unprocessableEntityError(ErrorCodeMFATOTPVerifyDisabled, "MFA verification is disabled for TOTP") + } + return a.verifyTOTPFactor(w, r, params) + case models.WebAuthn: + if !config.MFA.WebAuthn.VerifyEnabled { + return unprocessableEntityError(ErrorCodeMFAWebAuthnEnrollDisabled, "MFA verification is disabled for WebAuthn") + } + return a.verifyWebAuthnFactor(w, r, params) + default: + return badRequestError(ErrorCodeValidationFailed, "factor_type needs to be totp, phone, or webauthn") + } + +} + +func (a *API) UnenrollFactor(w http.ResponseWriter, r *http.Request) error { + var err error + ctx := r.Context() + user := getUser(ctx) + factor := getFactor(ctx) + session := getSession(ctx) + db := a.db.WithContext(ctx) + + if factor == nil || session == nil || user == nil { + return internalServerError("A valid session and factor are required to unenroll a factor") + } + + if factor.IsVerified() && !session.IsAAL2() { + return unprocessableEntityError(ErrorCodeInsufficientAAL, "AAL2 required to unenroll verified factor") + } + + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if terr := tx.Destroy(factor); terr != nil { + return terr + } + if terr = models.NewAuditLogEntry(r, tx, user, models.UnenrollFactorAction, r.RemoteAddr, map[string]interface{}{ + "factor_id": factor.ID, + "factor_status": factor.Status, + "session_id": session.ID, + }); terr != nil { + return terr + } + if terr = factor.DowngradeSessionsToAAL1(tx); terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, &UnenrollFactorResponse{ + ID: factor.ID, + }) +} diff --git a/auth_v2.169.0/internal/api/mfa_test.go b/auth_v2.169.0/internal/api/mfa_test.go new file mode 100644 index 0000000..653f38f --- /dev/null +++ b/auth_v2.169.0/internal/api/mfa_test.go @@ -0,0 +1,1011 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + 
"time" + + "github.com/gofrs/uuid" + + "github.com/pquerna/otp" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/utilities" + + "github.com/pquerna/otp/totp" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type MFATestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration + TestDomain string + TestEmail string + TestOTPKey *otp.Key + TestPassword string + TestUser *models.User + TestSession *models.Session + TestSecondarySession *models.Session +} + +func TestMFA(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + ts := &MFATestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + suite.Run(t, ts) +} + +func (ts *MFATestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + ts.TestEmail = "test@example.com" + ts.TestPassword = "password" + // Create user + u, err := models.NewUser("123456789", ts.TestEmail, ts.TestPassword, ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + // Create Factor + f := models.NewTOTPFactor(u, "test_factor") + require.NoError(ts.T(), f.SetSecret("secretkey", ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID, ts.Config.Security.DBEncryption.EncryptionKey)) + require.NoError(ts.T(), ts.API.db.Create(f), "Error saving new test factor") + // Create corresponding session + s, err := models.NewSession(u.ID, &f.ID) + require.NoError(ts.T(), err, "Error creating test session") + require.NoError(ts.T(), ts.API.db.Create(s), "Error saving test session") + + u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.TestEmail, ts.Config.JWT.Aud) + ts.Require().NoError(err) + + ts.TestUser = u + ts.TestSession = s + + secondarySession, err := models.NewSession(ts.TestUser.ID, &f.ID) + require.NoError(ts.T(), err, "Error creating test session") + require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session") + + ts.TestSecondarySession = secondarySession + + // Generate TOTP related settings + testDomain := strings.Split(ts.TestEmail, "@")[1] + ts.TestDomain = testDomain + + // By default MFA Phone is disabled + ts.Config.MFA.Phone.EnrollEnabled = true + ts.Config.MFA.Phone.VerifyEnabled = true + + ts.Config.MFA.WebAuthn.EnrollEnabled = true + ts.Config.MFA.WebAuthn.VerifyEnabled = true + + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: ts.TestDomain, + AccountName: ts.TestEmail, + }) + require.NoError(ts.T(), err) + ts.TestOTPKey = key + +} + +func (ts *MFATestSuite) generateAAL1Token(user *models.User, sessionId *uuid.UUID) string { + // Not an actual path. 
Dummy request to simulate a signup request that we can use in generateAccessToken + req := httptest.NewRequest(http.MethodPost, "/factors", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, user, sessionId, models.TOTPSignIn) + require.NoError(ts.T(), err, "Error generating access token") + return token +} + +func (ts *MFATestSuite) TestEnrollFactor() { + testFriendlyName := "bob" + alternativeFriendlyName := "john" + + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + + var cases = []struct { + desc string + friendlyName string + factorType string + issuer string + phone string + expectedCode int + }{ + { + desc: "TOTP: No issuer", + friendlyName: alternativeFriendlyName, + factorType: models.TOTP, + issuer: "", + phone: "", + expectedCode: http.StatusOK, + }, + { + desc: "Invalid factor type", + friendlyName: testFriendlyName, + factorType: "invalid_factor", + issuer: ts.TestDomain, + phone: "", + expectedCode: http.StatusBadRequest, + }, + { + desc: "TOTP: Factor has friendly name", + friendlyName: testFriendlyName, + factorType: models.TOTP, + issuer: ts.TestDomain, + phone: "", + expectedCode: http.StatusOK, + }, + { + desc: "TOTP: Enrolling without friendly name", + friendlyName: "", + factorType: models.TOTP, + issuer: ts.TestDomain, + phone: "", + expectedCode: http.StatusOK, + }, + { + desc: "Phone: Enroll with friendly name", + friendlyName: "phone_factor", + factorType: models.Phone, + phone: "+12345677889", + expectedCode: http.StatusOK, + }, + { + desc: "Phone: Enroll with invalid phone number", + friendlyName: "phone_factor", + factorType: models.Phone, + phone: "+1", + expectedCode: http.StatusBadRequest, + }, + { + desc: "Phone: Enroll without phone number should return error", + friendlyName: "phone_factor_fail", + factorType: models.Phone, + phone: "", + expectedCode: http.StatusBadRequest, + }, + { + desc: "WebAuthn: Enroll with friendly name", + friendlyName: "webauthn_factor", + factorType: models.WebAuthn, + expectedCode: http.StatusOK, + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + w := performEnrollFlow(ts, token, c.friendlyName, c.factorType, c.issuer, c.phone, c.expectedCode) + enrollResp := EnrollFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp)) + + if c.expectedCode == http.StatusOK { + addedFactor, err := models.FindFactorByFactorID(ts.API.db, enrollResp.ID) + require.NoError(ts.T(), err) + require.False(ts.T(), addedFactor.IsVerified()) + + if c.friendlyName != "" { + require.Equal(ts.T(), c.friendlyName, addedFactor.FriendlyName) + } + + if c.factorType == models.TOTP { + qrCode := enrollResp.TOTP.QRCode + hasSVGStartAndEnd := strings.Contains(qrCode, "") + require.True(ts.T(), hasSVGStartAndEnd) + require.Equal(ts.T(), c.friendlyName, enrollResp.FriendlyName) + } + } + + }) + } +} + +func (ts *MFATestSuite) TestDuplicateEnrollPhoneFactor() { + testPhoneNumber := "+12345677889" + altPhoneNumber := "+987412444444" + friendlyName := "phone_factor" + altFriendlyName := "alt_phone_factor" + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + + var cases = []struct { + desc string + earlierFactorName string + laterFactorName string + phone string + secondPhone string + expectedCode int + expectedNumberOfFactors int + }{ + { + desc: "Phone: Only the latest factor should persist when enrolling two unverified phone factors with the same number", + earlierFactorName: friendlyName, + laterFactorName: altFriendlyName, + phone: testPhoneNumber, + secondPhone: 
testPhoneNumber, + expectedNumberOfFactors: 1, + }, + + { + desc: "Phone: Both factors should persist when enrolling two different unverified numbers", + earlierFactorName: friendlyName, + laterFactorName: altFriendlyName, + phone: testPhoneNumber, + secondPhone: altPhoneNumber, + expectedNumberOfFactors: 2, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // Delete all test factors to start from clean slate + require.NoError(ts.T(), ts.API.db.Destroy(ts.TestUser.Factors)) + _ = performEnrollFlow(ts, token, c.earlierFactorName, models.Phone, ts.TestDomain, c.phone, http.StatusOK) + + w := performEnrollFlow(ts, token, c.laterFactorName, models.Phone, ts.TestDomain, c.secondPhone, http.StatusOK) + enrollResp := EnrollFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp)) + + laterFactor, err := models.FindFactorByFactorID(ts.API.db, enrollResp.ID) + require.NoError(ts.T(), err) + require.False(ts.T(), laterFactor.IsVerified()) + + require.NoError(ts.T(), ts.API.db.Eager("Factors").Find(ts.TestUser, ts.TestUser.ID)) + require.Equal(ts.T(), len(ts.TestUser.Factors), c.expectedNumberOfFactors) + + }) + } +} + +func (ts *MFATestSuite) TestDuplicateEnrollPhoneFactorWithVerified() { + testPhoneNumber := "+12345677889" + friendlyName := "phone_factor" + altFriendlyName := "alt_phone_factor" + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + + ts.Run("Phone: Enrolling a factor with the same number as an existing verified phone factor should result in an error", func() { + require.NoError(ts.T(), ts.API.db.Destroy(ts.TestUser.Factors)) + + // Setup verified factor + w := performEnrollFlow(ts, token, friendlyName, models.Phone, ts.TestDomain, testPhoneNumber, http.StatusOK) + enrollResp := EnrollFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp)) + firstFactor, err := models.FindFactorByFactorID(ts.API.db, enrollResp.ID) + require.NoError(ts.T(), err) + require.NoError(ts.T(), firstFactor.UpdateStatus(ts.API.db, models.FactorStateVerified)) + + expectedStatusCode := http.StatusUnprocessableEntity + _ = performEnrollFlow(ts, token, altFriendlyName, models.Phone, ts.TestDomain, testPhoneNumber, expectedStatusCode) + + require.NoError(ts.T(), ts.API.db.Eager("Factors").Find(ts.TestUser, ts.TestUser.ID)) + require.Equal(ts.T(), len(ts.TestUser.Factors), 1) + }) +} + +func (ts *MFATestSuite) TestDuplicateTOTPEnrollsReturnExpectedMessage() { + friendlyName := "mary" + issuer := "https://issuer.com" + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + _ = performEnrollFlow(ts, token, friendlyName, models.TOTP, issuer, "", http.StatusOK) + response := performEnrollFlow(ts, token, friendlyName, models.TOTP, issuer, "", http.StatusUnprocessableEntity) + + var errorResponse HTTPError + err := json.NewDecoder(response.Body).Decode(&errorResponse) + require.NoError(ts.T(), err) + + require.Contains(ts.T(), errorResponse.ErrorCode, ErrorCodeMFAFactorNameConflict) +} + +func (ts *MFATestSuite) AAL2RequiredToUpdatePasswordAfterEnrollment() { + resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */) + accessTokenResp := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp)) + + var w *httptest.ResponseRecorder + var buffer bytes.Buffer + token := accessTokenResp.Token + // Update Password to new password + newPassword := "newpass" + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + 
"password": newPassword, + })) + + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // Logout + reqURL := "http://localhost/logout" + req = httptest.NewRequest(http.MethodPost, reqURL, nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusNoContent, w.Code) + + // Get AAL1 token + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": ts.TestEmail, + "password": newPassword, + })) + + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + session1 := AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&session1)) + + // Update Password again, this should fail + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "password": ts.TestPassword, + })) + + req = httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session1.Token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusUnauthorized, w.Code) + +} + +func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() { + // All factors are deleted when a subsequent enroll is made + ts.API.config.MFA.FactorExpiryDuration = 0 * time.Second + // Verified factor should not be deleted (Factor 1) + resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */) + numFactors := 5 + accessTokenResp := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp)) + + var w *httptest.ResponseRecorder + token := accessTokenResp.Token + for i := 0; i < numFactors; i++ { + w = performEnrollFlow(ts, token, "first-name", models.TOTP, "https://issuer.com", "", http.StatusOK) + } + + enrollResp := EnrollFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp)) + + // Make a challenge so last, unverified factor isn't deleted on next enroll (Factor 2) + _ = performChallengeFlow(ts, enrollResp.ID, token) + + // Enroll another Factor (Factor 3) + _ = performEnrollFlow(ts, token, "second-name", models.TOTP, "https://issuer.com", "", http.StatusOK) + require.NoError(ts.T(), ts.API.db.Eager("Factors").Find(ts.TestUser, ts.TestUser.ID)) + require.Equal(ts.T(), 3, len(ts.TestUser.Factors)) +} + +func (ts *MFATestSuite) TestChallengeTOTPFactor() { + // Test Factor is a TOTP Factor + f := ts.TestUser.Factors[0] + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + w := performChallengeFlow(ts, f.ID, token) + challengeResp := ChallengeFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&challengeResp)) + + require.Equal(ts.T(), http.StatusOK, w.Code) + require.Equal(ts.T(), challengeResp.Type, models.TOTP) + +} + +func (ts *MFATestSuite) TestChallengeSMSFactor() { + // Challenge should still work with phone provider disabled + ts.Config.External.Phone.Enabled = false + 
ts.Config.Hook.SendSMS.Enabled = true + ts.Config.Hook.SendSMS.URI = "pg-functions://postgres/auth/send_sms_mfa_mock" + + ts.Config.MFA.Phone.MaxFrequency = 0 * time.Second + + require.NoError(ts.T(), ts.Config.Hook.SendSMS.PopulateExtensibilityPoint()) + require.NoError(ts.T(), ts.API.db.RawQuery(` + create or replace function send_sms_mfa_mock(input jsonb) + returns json as $$ + begin + return input; + end; $$ language plpgsql;`).Exec()) + + phone := "+1234567" + friendlyName := "testchallengesmsfactor" + + f := models.NewPhoneFactor(ts.TestUser, phone, friendlyName) + require.NoError(ts.T(), ts.API.db.Create(f), "Error creating new SMS factor") + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + + var cases = []struct { + desc string + channel string + expectedCode int + }{ + { + desc: "SMS Channel", + channel: sms_provider.SMSProvider, + expectedCode: http.StatusOK, + }, + { + desc: "WhatsApp Channel", + channel: sms_provider.WhatsappProvider, + expectedCode: http.StatusOK, + }, + } + + for _, tc := range cases { + ts.Run(tc.desc, func() { + w := performSMSChallengeFlow(ts, f.ID, token, tc.channel) + challengeResp := ChallengeFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&challengeResp)) + require.Equal(ts.T(), challengeResp.Type, models.Phone) + require.Equal(ts.T(), tc.expectedCode, w.Code, tc.desc) + }) + } +} + +func (ts *MFATestSuite) TestMFAVerifyFactor() { + cases := []struct { + desc string + validChallenge bool + validCode bool + factorType string + expectedHTTPCode int + }{ + { + desc: "Invalid: Valid code and expired challenge", + validChallenge: false, + validCode: true, + factorType: models.TOTP, + expectedHTTPCode: http.StatusUnprocessableEntity, + }, + { + desc: "Invalid: Invalid code and valid challenge", + validChallenge: true, + validCode: false, + factorType: models.TOTP, + expectedHTTPCode: http.StatusUnprocessableEntity, + }, + { + desc: "Valid /verify request", + validChallenge: true, + validCode: true, + factorType: models.TOTP, + expectedHTTPCode: http.StatusOK, + }, + { + desc: "Invalid: Valid code and expired challenge (SMS)", + validChallenge: false, + validCode: true, + factorType: models.Phone, + expectedHTTPCode: http.StatusUnprocessableEntity, + }, + { + desc: "Invalid: Invalid code and valid challenge (SMS)", + validChallenge: true, + validCode: false, + factorType: models.Phone, + expectedHTTPCode: http.StatusUnprocessableEntity, + }, + { + desc: "Valid /verify request (SMS)", + validChallenge: true, + validCode: true, + factorType: models.Phone, + expectedHTTPCode: http.StatusOK, + }, + } + for _, v := range cases { + ts.Run(v.desc, func() { + // Authenticate users and set secret + var buffer bytes.Buffer + r, err := models.GrantAuthenticatedUser(ts.API.db, ts.TestUser, models.GrantParams{}) + require.NoError(ts.T(), err) + token := ts.generateAAL1Token(ts.TestUser, r.SessionId) + var f *models.Factor + var sharedSecret string + + if v.factorType == models.TOTP { + friendlyName := uuid.Must(uuid.NewV4()).String() + f = models.NewTOTPFactor(ts.TestUser, friendlyName) + sharedSecret = ts.TestOTPKey.Secret() + f.Secret = sharedSecret + require.NoError(ts.T(), ts.API.db.Create(f), "Error updating new test factor") + } else if v.factorType == models.Phone { + friendlyName := uuid.Must(uuid.NewV4()).String() + numDigits := 10 + otp := crypto.GenerateOtp(numDigits) + require.NoError(ts.T(), err) + phone := fmt.Sprintf("+%s", otp) + f = models.NewPhoneFactor(ts.TestUser, phone, friendlyName) + require.NoError(ts.T(), 
ts.API.db.Create(f), "Error creating new SMS factor") + } + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/factors/%s/verify", f.ID), &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + var c *models.Challenge + var code string + if v.factorType == models.TOTP { + c = f.CreateChallenge(utilities.GetIPAddress(req)) + // Verify TOTP code + code, err = totp.GenerateCode(sharedSecret, time.Now().UTC()) + require.NoError(ts.T(), err) + } else if v.factorType == models.Phone { + code = "123456" + c, err = f.CreatePhoneChallenge(utilities.GetIPAddress(req), code, ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID, ts.Config.Security.DBEncryption.EncryptionKey) + require.NoError(ts.T(), err) + } + + if !v.validCode && v.factorType == models.TOTP { + code, err = totp.GenerateCode(sharedSecret, time.Now().UTC().Add(-1*time.Minute*time.Duration(1))) + require.NoError(ts.T(), err) + + } else if !v.validCode && v.factorType == models.Phone { + invalidSuffix := "1" + code += invalidSuffix + } + + require.NoError(ts.T(), ts.API.db.Create(c), "Error saving new test challenge") + if !v.validChallenge { + // Set challenge creation so that it has expired in present time. + newCreatedAt := time.Now().UTC().Add(-1 * time.Second * time.Duration(ts.Config.MFA.ChallengeExpiryDuration+1)) + // created_at is managed by buffalo(ORM) needs to be raw query to be updated + err := ts.API.db.RawQuery("UPDATE auth.mfa_challenges SET created_at = ? WHERE factor_id = ?", newCreatedAt, f.ID).Exec() + require.NoError(ts.T(), err, "Error updating new test challenge") + } + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "challenge_id": c.ID, + "code": code, + })) + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), v.expectedHTTPCode, w.Code) + + if v.expectedHTTPCode == http.StatusOK { + // Ensure alternate session has been deleted + _, err = models.FindSessionByID(ts.API.db, ts.TestSecondarySession.ID, false) + require.EqualError(ts.T(), err, models.SessionNotFoundError{}.Error()) + } + if !v.validChallenge { + // Ensure invalid challenges are deleted + _, err := f.FindChallengeByID(ts.API.db, c.ID) + require.EqualError(ts.T(), err, models.ChallengeNotFoundError{}.Error()) + } + }) + } +} + +func (ts *MFATestSuite) TestUnenrollVerifiedFactor() { + cases := []struct { + desc string + isAAL2 bool + expectedHTTPCode int + }{ + { + desc: "Verified Factor: AAL1", + isAAL2: false, + expectedHTTPCode: http.StatusUnprocessableEntity, + }, + { + desc: "Verified Factor: AAL2, Success", + isAAL2: true, + expectedHTTPCode: http.StatusOK, + }, + } + for _, v := range cases { + ts.Run(v.desc, func() { + var buffer bytes.Buffer + + // Create Session to test behaviour which downgrades other sessions + f := ts.TestUser.Factors[0] + require.NoError(ts.T(), f.UpdateStatus(ts.API.db, models.FactorStateVerified)) + if v.isAAL2 { + ts.TestSession.UpdateAALAndAssociatedFactor(ts.API.db, models.AAL2, &f.ID) + } + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + w := ServeAuthenticatedRequest(ts, http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), token, buffer) + require.Equal(ts.T(), v.expectedHTTPCode, w.Code) + + if v.expectedHTTPCode == http.StatusOK { + _, err := models.FindFactorByFactorID(ts.API.db, f.ID) + require.EqualError(ts.T(), err, models.FactorNotFoundError{}.Error()) + session, _ := models.FindSessionByID(ts.API.db, ts.TestSecondarySession.ID, false) + 
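+				// Unenrolling the verified factor should also downgrade other sessions
+				// to AAL1 and detach them from the deleted factor.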
require.Equal(ts.T(), models.AAL1.String(), session.GetAAL()) + require.Nil(ts.T(), session.FactorID) + + } + }) + } + +} + +func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() { + var buffer bytes.Buffer + f := ts.TestUser.Factors[0] + + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "factor_id": f.ID, + })) + + w := ServeAuthenticatedRequest(ts, http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), token, buffer) + require.Equal(ts.T(), http.StatusOK, w.Code) + + _, err := models.FindFactorByFactorID(ts.API.db, f.ID) + require.EqualError(ts.T(), err, models.FactorNotFoundError{}.Error()) + session, _ := models.FindSessionByID(ts.API.db, ts.TestSecondarySession.ID, false) + require.Equal(ts.T(), models.AAL1.String(), session.GetAAL()) + require.Nil(ts.T(), session.FactorID) + +} + +// Integration Tests +func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() { + ts.Config.Security.RefreshTokenRotationEnabled = true + resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */) + accessTokenResp := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp)) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": accessTokenResp.RefreshToken, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + ctx, err := ts.API.parseJWTClaims(data.Token, req) + require.NoError(ts.T(), err) + ctx, err = ts.API.maybeLoadUserOrSession(ctx) + require.NoError(ts.T(), err) + require.True(ts.T(), getSession(ctx).IsAAL2()) +} + +// Performing MFA Verification followed by a sign in should return an AAL1 session and an AAL2 session +func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() { + ts.Config.Security.RefreshTokenRotationEnabled = true + resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */) + accessTokenResp := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp)) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": ts.TestEmail, + "password": ts.TestPassword, + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + ctx, err := ts.API.parseJWTClaims(data.Token, req) + require.NoError(ts.T(), err) + + ctx, err = ts.API.maybeLoadUserOrSession(ctx) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), models.AAL1.String(), getSession(ctx).GetAAL()) + session, err := models.FindSessionByUserID(ts.API.db, accessTokenResp.User.ID) + require.NoError(ts.T(), err) + require.True(ts.T(), session.IsAAL2()) +} + +func (ts *MFATestSuite) TestChallengeWebAuthnFactor() { + factor := models.NewWebAuthnFactor(ts.TestUser, "WebAuthnfactor") + validWebAuthnConfiguration 
:= &WebAuthnParams{ + RPID: "localhost", + RPOrigins: "http://localhost:3000", + } + require.NoError(ts.T(), ts.API.db.Create(factor), "Error saving new test factor") + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + w := performChallengeWebAuthnFlow(ts, factor.ID, token, validWebAuthnConfiguration) + require.Equal(ts.T(), http.StatusOK, w.Code) +} + +func performChallengeWebAuthnFlow(ts *MFATestSuite, factorID uuid.UUID, token string, webauthn *WebAuthnParams) *httptest.ResponseRecorder { + var buffer bytes.Buffer + err := json.NewEncoder(&buffer).Encode(ChallengeFactorParams{WebAuthn: webauthn}) + require.NoError(ts.T(), err) + w := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), token, buffer) + require.Equal(ts.T(), http.StatusOK, w.Code) + return w +} + +func (ts *MFATestSuite) TestChallengeFactorNotOwnedByUser() { + var buffer bytes.Buffer + email := "nomfaenabled@test.com" + password := "testpassword" + signUpResp := signUp(ts, email, password) + + friendlyName := "testfactor" + phoneNumber := "+1234567" + + otherUsersPhoneFactor := models.NewPhoneFactor(ts.TestUser, phoneNumber, friendlyName) + require.NoError(ts.T(), ts.API.db.Create(otherUsersPhoneFactor), "Error creating factor") + + w := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", otherUsersPhoneFactor.ID), signUpResp.Token, buffer) + + expectedError := notFoundError(ErrorCodeMFAFactorNotFound, "Factor not found") + + var data HTTPError + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.Equal(ts.T(), expectedError.ErrorCode, data.ErrorCode) + require.Equal(ts.T(), http.StatusNotFound, w.Code) + +} + +func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenResponse) { + ts.API.config.Mailer.Autoconfirm = true + var buffer bytes.Buffer + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": email, + "password": password, + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + data := AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + return data +} + +func performTestSignupAndVerify(ts *MFATestSuite, email, password string, requireStatusOK bool) *httptest.ResponseRecorder { + signUpResp := signUp(ts, email, password) + resp := performEnrollAndVerify(ts, signUpResp.Token, requireStatusOK) + + return resp + +} + +func performEnrollFlow(ts *MFATestSuite, token, friendlyName, factorType, issuer string, phone string, expectedCode int) *httptest.ResponseRecorder { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(EnrollFactorParams{FriendlyName: friendlyName, FactorType: factorType, Issuer: issuer, Phone: phone})) + w := ServeAuthenticatedRequest(ts, http.MethodPost, "http://localhost/factors/", token, buffer) + require.Equal(ts.T(), expectedCode, w.Code) + return w +} + +func ServeAuthenticatedRequest(ts *MFATestSuite, method, path, token string, buffer bytes.Buffer) *httptest.ResponseRecorder { + w := httptest.NewRecorder() + req := httptest.NewRequest(method, path, &buffer) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + req.Header.Set("Content-Type", "application/json") + + ts.API.handler.ServeHTTP(w, req) 
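+	// Callers inspect w.Code and w.Body themselves; this helper does not assert on the status.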
+ return w +} + +func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token string, requireStatusOK bool) *httptest.ResponseRecorder { + var buffer bytes.Buffer + + factor, err := models.FindFactorByFactorID(ts.API.db, factorID) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), factor) + + totpSecret := factor.Secret + + if es := crypto.ParseEncryptedString(factor.Secret); es != nil { + secret, err := es.Decrypt(factor.ID.String(), ts.API.config.Security.DBEncryption.DecryptionKeys) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), secret) + + totpSecret = string(secret) + } + + code, err := totp.GenerateCode(totpSecret, time.Now().UTC()) + require.NoError(ts.T(), err) + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "challenge_id": challengeID, + "code": code, + })) + + y := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("/factors/%s/verify", factorID), token, buffer) + + if requireStatusOK { + require.Equal(ts.T(), http.StatusOK, y.Code) + } + return y +} + +func performChallengeFlow(ts *MFATestSuite, factorID uuid.UUID, token string) *httptest.ResponseRecorder { + var buffer bytes.Buffer + w := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), token, buffer) + require.Equal(ts.T(), http.StatusOK, w.Code) + return w + +} + +func performSMSChallengeFlow(ts *MFATestSuite, factorID uuid.UUID, token, channel string) *httptest.ResponseRecorder { + params := ChallengeFactorParams{ + Channel: channel, + } + var buffer bytes.Buffer + if err := json.NewEncoder(&buffer).Encode(params); err != nil { + panic(err) // handle the error appropriately in real code + } + + w := ServeAuthenticatedRequest(ts, http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), token, buffer) + require.Equal(ts.T(), http.StatusOK, w.Code) + return w + +} + +func performEnrollAndVerify(ts *MFATestSuite, token string, requireStatusOK bool) *httptest.ResponseRecorder { + w := performEnrollFlow(ts, token, "", models.TOTP, ts.TestDomain, "", http.StatusOK) + enrollResp := EnrollFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp)) + factorID := enrollResp.ID + + // Challenge + w = performChallengeFlow(ts, factorID, token) + + challengeResp := EnrollFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&challengeResp)) + challengeID := challengeResp.ID + + // Verify + y := performVerifyFlow(ts, challengeID, factorID, token, requireStatusOK) + + return y +} + +func (ts *MFATestSuite) TestVerificationHooks() { + type verificationHookTestCase struct { + desc string + enabled bool + uri string + hookFunctionSQL string + emailSuffix string + expectToken bool + expectedCode int + cleanupHookFunction string + } + cases := []verificationHookTestCase{ + { + desc: "Default Success", + enabled: true, + uri: "pg-functions://postgres/auth/verification_hook", + hookFunctionSQL: ` + create or replace function verification_hook(input jsonb) + returns json as $$ + begin + return json_build_object('decision', 'continue'); + end; $$ language plpgsql;`, + emailSuffix: "success", + expectToken: true, + expectedCode: http.StatusOK, + cleanupHookFunction: "verification_hook(input jsonb)", + }, + { + desc: "Error", + enabled: true, + uri: "pg-functions://postgres/auth/test_verification_hook_error", + hookFunctionSQL: ` + create or replace function test_verification_hook_error(input jsonb) + returns json as $$ + begin + RAISE 
EXCEPTION 'Intentional Error for Testing'; + end; $$ language plpgsql;`, + emailSuffix: "error", + expectToken: false, + expectedCode: http.StatusInternalServerError, + cleanupHookFunction: "test_verification_hook_error(input jsonb)", + }, + { + desc: "Reject - Enabled", + enabled: true, + uri: "pg-functions://postgres/auth/verification_hook_reject", + hookFunctionSQL: ` + create or replace function verification_hook_reject(input jsonb) + returns json as $$ + begin + return json_build_object( + 'decision', 'reject', + 'message', 'authentication attempt rejected' + ); + end; $$ language plpgsql;`, + emailSuffix: "reject_enabled", + expectToken: false, + expectedCode: http.StatusForbidden, + cleanupHookFunction: "verification_hook_reject(input jsonb)", + }, + { + desc: "Reject - Disabled", + enabled: false, + uri: "pg-functions://postgres/auth/verification_hook_reject", + hookFunctionSQL: ` + create or replace function verification_hook_reject(input jsonb) + returns json as $$ + begin + return json_build_object( + 'decision', 'reject', + 'message', 'authentication attempt rejected' + ); + end; $$ language plpgsql;`, + emailSuffix: "reject_disabled", + expectToken: true, + expectedCode: http.StatusOK, + cleanupHookFunction: "verification_hook_reject(input jsonb)", + }, + { + desc: "Timeout", + enabled: true, + uri: "pg-functions://postgres/auth/test_verification_hook_timeout", + hookFunctionSQL: ` + create or replace function test_verification_hook_timeout(input jsonb) + returns json as $$ + begin + PERFORM pg_sleep(3); + return json_build_object( + 'decision', 'continue' + ); + end; $$ language plpgsql;`, + emailSuffix: "timeout", + expectToken: false, + expectedCode: http.StatusInternalServerError, + cleanupHookFunction: "test_verification_hook_timeout(input jsonb)", + }, + } + + for _, c := range cases { + ts.T().Run(c.desc, func(t *testing.T) { + ts.Config.Hook.MFAVerificationAttempt.Enabled = c.enabled + ts.Config.Hook.MFAVerificationAttempt.URI = c.uri + require.NoError(ts.T(), ts.Config.Hook.MFAVerificationAttempt.PopulateExtensibilityPoint()) + + err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec() + require.NoError(t, err) + + email := fmt.Sprintf("testemail_%s@gmail.com", c.emailSuffix) + password := "testpassword" + resp := performTestSignupAndVerify(ts, email, password, c.expectToken) + require.Equal(ts.T(), c.expectedCode, resp.Code) + accessTokenResp := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp)) + + if c.expectToken { + require.NotEqual(t, "", accessTokenResp.Token) + } else { + require.Equal(t, "", accessTokenResp.Token) + } + + cleanupHook(ts, c.cleanupHookFunction) + }) + } +} + +func cleanupHook(ts *MFATestSuite, hookName string) { + cleanupHookSQL := fmt.Sprintf("drop function if exists %s", hookName) + err := ts.API.db.RawQuery(cleanupHookSQL).Exec() + require.NoError(ts.T(), err) +} diff --git a/auth_v2.169.0/internal/api/middleware.go b/auth_v2.169.0/internal/api/middleware.go new file mode 100644 index 0000000..d387936 --- /dev/null +++ b/auth_v2.169.0/internal/api/middleware.go @@ -0,0 +1,401 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "time" + + chimiddleware "github.com/go-chi/chi/v5/middleware" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/security" + "github.com/supabase/auth/internal/utilities" + + 
"github.com/didip/tollbooth/v5" + "github.com/didip/tollbooth/v5/limiter" + jwt "github.com/golang-jwt/jwt/v5" +) + +type FunctionHooks map[string][]string + +type AuthMicroserviceClaims struct { + jwt.RegisteredClaims + SiteURL string `json:"site_url"` + InstanceID string `json:"id"` + FunctionHooks FunctionHooks `json:"function_hooks"` +} + +func (f *FunctionHooks) UnmarshalJSON(b []byte) error { + var raw map[string][]string + err := json.Unmarshal(b, &raw) + if err == nil { + *f = FunctionHooks(raw) + return nil + } + // If unmarshaling into map[string][]string fails, try legacy format. + var legacy map[string]string + err = json.Unmarshal(b, &legacy) + if err != nil { + return err + } + if *f == nil { + *f = make(FunctionHooks) + } + for event, hook := range legacy { + (*f)[event] = []string{hook} + } + return nil +} + +var emailRateLimitCounter = observability.ObtainMetricCounter("gotrue_email_rate_limit_counter", "Number of times an email rate limit has been triggered") + +func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { + return func(w http.ResponseWriter, req *http.Request) (context.Context, error) { + c := req.Context() + + if limitHeader := a.config.RateLimitHeader; limitHeader != "" { + key := req.Header.Get(limitHeader) + + if key == "" { + log := observability.GetLogEntry(req).Entry + log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied") + return c, nil + } else { + err := tollbooth.LimitByKeys(lmt, []string{key}) + if err != nil { + return c, tooManyRequestsError(ErrorCodeOverRequestRateLimit, "Request rate limit reached") + } + } + } + return c, nil + } +} + +func (a *API) requireAdminCredentials(w http.ResponseWriter, req *http.Request) (context.Context, error) { + t, err := a.extractBearerToken(req) + if err != nil || t == "" { + return nil, err + } + + ctx, err := a.parseJWTClaims(t, req) + if err != nil { + return nil, err + } + + return a.requireAdmin(ctx) +} + +func (a *API) requireEmailProvider(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + config := a.config + + if !config.External.Email.Enabled { + return nil, badRequestError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") + } + + return ctx, nil +} + +func (a *API) verifyCaptcha(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + config := a.config + + if !config.Security.Captcha.Enabled { + return ctx, nil + } + if _, err := a.requireAdminCredentials(w, req); err == nil { + // skip captcha validation if authorization header contains an admin role + return ctx, nil + } + if shouldIgnore := isIgnoreCaptchaRoute(req); shouldIgnore { + return ctx, nil + } + + body := &security.GotrueRequest{} + if err := retrieveRequestParams(req, body); err != nil { + return nil, err + } + + verificationResult, err := security.VerifyRequest(body, utilities.GetIPAddress(req), strings.TrimSpace(config.Security.Captcha.Secret), config.Security.Captcha.Provider) + if err != nil { + return nil, internalServerError("captcha verification process failed").WithInternalError(err) + } + + if !verificationResult.Success { + return nil, badRequestError(ErrorCodeCaptchaFailed, "captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", ")) + } + + return ctx, nil +} + +func isIgnoreCaptchaRoute(req *http.Request) bool { + // captcha shouldn't be enabled on the following grant_types + // id_token, refresh_token, 
pkce + if req.URL.Path == "/token" && req.FormValue("grant_type") != "password" { + return true + } + return false +} + +func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + config := a.config + + xForwardedHost := req.Header.Get("X-Forwarded-Host") + xForwardedProto := req.Header.Get("X-Forwarded-Proto") + reqHost := req.URL.Hostname() + + if len(config.Mailer.ExternalHosts) > 0 { + // this server is configured to accept multiple external hosts, validate the host from the X-Forwarded-Host or Host headers + + hostname := "" + protocol := "https" + + if xForwardedHost != "" { + for _, host := range config.Mailer.ExternalHosts { + if host == xForwardedHost { + hostname = host + break + } + } + } else if reqHost != "" { + for _, host := range config.Mailer.ExternalHosts { + if host == reqHost { + hostname = host + break + } + } + } + + if hostname != "" { + if hostname == "localhost" { + // allow the use of HTTP only if the accepted hostname was localhost + if xForwardedProto == "http" || req.URL.Scheme == "http" { + protocol = "http" + } + } + + externalHostURL, err := url.ParseRequestURI(fmt.Sprintf("%s://%s", protocol, hostname)) + if err != nil { + return ctx, err + } + + return withExternalHost(ctx, externalHostURL), nil + } + } + + if xForwardedHost != "" || reqHost != "" { + // host has been provided to the request, but it hasn't been + // added to the allow list, raise a log message + // in Supabase platform the X-Forwarded-Host and full request + // URL are likely sanitzied before they reach the server + + fields := make(logrus.Fields) + + if xForwardedHost != "" { + fields["x_forwarded_host"] = xForwardedHost + } + + if xForwardedProto != "" { + fields["x_forwarded_proto"] = xForwardedProto + } + + if reqHost != "" { + fields["request_url_host"] = reqHost + + if req.URL.Scheme != "" { + fields["request_url_scheme"] = req.URL.Scheme + } + } + + logrus.WithFields(fields).Info("Request received external host in X-Forwarded-Host or Host headers, but the values have not been added to GOTRUE_MAILER_EXTERNAL_HOSTS and will not be used. 
To suppress this message add the host, or sanitize the headers before the request reaches Auth.") + } + + // either the provided external hosts don't match the allow list, or + // the server is not configured to accept multiple hosts -- use the + // configured external URL instead + + externalHostURL, err := url.ParseRequestURI(config.API.ExternalURL) + if err != nil { + return ctx, err + } + + return withExternalHost(ctx, externalHostURL), nil +} + +func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + if !a.config.SAML.Enabled { + return nil, notFoundError(ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled") + } + return ctx, nil +} + +func (a *API) requireManualLinkingEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { + ctx := req.Context() + if !a.config.Security.ManualLinkingEnabled { + return nil, notFoundError(ErrorCodeManualLinkingDisabled, "Manual linking is disabled") + } + return ctx, nil +} + +func (a *API) databaseCleanup(cleanup models.Cleaner) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wrappedResp := chimiddleware.NewWrapResponseWriter(w, r.ProtoMajor) + next.ServeHTTP(wrappedResp, r) + switch r.Method { + case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete: + if (wrappedResp.Status() / 100) != 2 { + // don't do any cleanups for non-2xx responses + return + } + // continue + default: + return + } + + db := a.db.WithContext(r.Context()) + log := observability.GetLogEntry(r).Entry + + affectedRows, err := cleanup.Clean(db) + if err != nil { + log.WithError(err).WithField("affected_rows", affectedRows).Warn("database cleanup failed") + } else if affectedRows > 0 { + log.WithField("affected_rows", affectedRows).Debug("cleaned up expired or stale rows") + } + }) + } +} + +// timeoutResponseWriter is a http.ResponseWriter that queues up a response +// body to be sent if the serving completes before the context has exceeded its +// deadline. 
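+// Writes go to an in-memory buffer and the header map is snapshotted on the first
+// WriteHeader call; finallyWrite replays both onto the real http.ResponseWriter
+// once the handler has finished in time.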
+type timeoutResponseWriter struct { + sync.Mutex + + header http.Header + wroteHeader bool + snapHeader http.Header // snapshot of the header at the time WriteHeader was called + statusCode int + buf bytes.Buffer +} + +func (t *timeoutResponseWriter) Header() http.Header { + t.Lock() + defer t.Unlock() + + return t.header +} + +func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) { + t.Lock() + defer t.Unlock() + + if !t.wroteHeader { + t.writeHeaderLocked(http.StatusOK) + } + + return t.buf.Write(bytes) +} + +func (t *timeoutResponseWriter) WriteHeader(statusCode int) { + t.Lock() + defer t.Unlock() + + t.writeHeaderLocked(statusCode) +} + +func (t *timeoutResponseWriter) writeHeaderLocked(statusCode int) { + if t.wroteHeader { + // ignore multiple calls to WriteHeader + // once WriteHeader has been called once, a snapshot of the header map is taken + // and saved in snapHeader to be used in finallyWrite + return + } + + t.statusCode = statusCode + t.wroteHeader = true + t.snapHeader = t.header.Clone() +} + +func (t *timeoutResponseWriter) finallyWrite(w http.ResponseWriter) { + t.Lock() + defer t.Unlock() + + dst := w.Header() + for k, vv := range t.snapHeader { + dst[k] = vv + } + + if !t.wroteHeader { + t.statusCode = http.StatusOK + } + + w.WriteHeader(t.statusCode) + if _, err := w.Write(t.buf.Bytes()); err != nil { + logrus.WithError(err).Warn("Write failed") + } +} + +func timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), timeout) + defer cancel() + + timeoutWriter := &timeoutResponseWriter{ + header: make(http.Header), + } + + panicChan := make(chan any, 1) + serverDone := make(chan struct{}) + go func() { + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + + next.ServeHTTP(timeoutWriter, r.WithContext(ctx)) + close(serverDone) + }() + + select { + case p := <-panicChan: + panic(p) + + case <-serverDone: + timeoutWriter.finallyWrite(w) + + case <-ctx.Done(): + err := ctx.Err() + + if err == context.DeadlineExceeded { + httpError := &HTTPError{ + HTTPStatus: http.StatusGatewayTimeout, + ErrorCode: ErrorCodeRequestTimeout, + Message: "Processing this request timed out, please retry after a moment.", + } + + httpError = httpError.WithInternalError(err) + + HandleResponseError(httpError, w, r) + } else { + // unrecognized context error, so we should wait for the server to finish + // and write out the response + <-serverDone + + timeoutWriter.finallyWrite(w) + } + } + }) + } +} diff --git a/auth_v2.169.0/internal/api/middleware_test.go b/auth_v2.169.0/internal/api/middleware_test.go new file mode 100644 index 0000000..98dd6a8 --- /dev/null +++ b/auth_v2.169.0/internal/api/middleware_test.go @@ -0,0 +1,510 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/didip/tollbooth/v5" + "github.com/didip/tollbooth/v5/limiter" + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" +) + +const ( + HCaptchaSecret string = "0x0000000000000000000000000000000000000000" + CaptchaResponse string = "10000000-aaaa-bbbb-cccc-000000000001" + 
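+	// Dummy provider credentials: the valid-captcha cases below expect verification
+	// to succeed with these, while the invalid cases substitute non-dummy secrets.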
TurnstileCaptchaSecret string = "1x0000000000000000000000000000000AA" +) + +type MiddlewareTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestMiddlewareFunctions(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &MiddlewareTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *MiddlewareTestSuite) TestVerifyCaptchaValid() { + ts.Config.Security.Captcha.Enabled = true + + adminClaims := &AccessTokenClaims{ + Role: "supabase_admin", + } + adminJwt, err := jwt.NewWithClaims(jwt.SigningMethodHS256, adminClaims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err) + cases := []struct { + desc string + adminJwt string + captcha_token string + captcha_provider string + }{ + { + "Valid captcha response", + "", + CaptchaResponse, + "hcaptcha", + }, + { + "Valid captcha response", + "", + CaptchaResponse, + "turnstile", + }, + { + "Ignore captcha if admin role is present", + adminJwt, + "", + "hcaptcha", + }, + { + "Ignore captcha if admin role is present", + adminJwt, + "", + "turnstile", + }, + } + for _, c := range cases { + ts.Config.Security.Captcha.Provider = c.captcha_provider + if c.captcha_provider == "turnstile" { + ts.Config.Security.Captcha.Secret = TurnstileCaptchaSecret + } else if c.captcha_provider == "hcaptcha" { + ts.Config.Security.Captcha.Secret = HCaptchaSecret + } + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "password": "secret", + "gotrue_meta_security": map[string]interface{}{ + "captcha_token": c.captcha_token, + }, + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) + req.Header.Set("Content-Type", "application/json") + if c.adminJwt != "" { + req.Header.Set("Authorization", "Bearer "+c.adminJwt) + } + + beforeCtx := context.Background() + req = req.WithContext(beforeCtx) + + w := httptest.NewRecorder() + + afterCtx, err := ts.API.verifyCaptcha(w, req) + require.NoError(ts.T(), err) + + body, err := io.ReadAll(req.Body) + require.NoError(ts.T(), err) + + // re-initialize buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "password": "secret", + "gotrue_meta_security": map[string]interface{}{ + "captcha_token": c.captcha_token, + }, + })) + + // check if body is the same + require.Equal(ts.T(), body, buffer.Bytes()) + require.Equal(ts.T(), afterCtx, beforeCtx) + } +} + +func (ts *MiddlewareTestSuite) TestVerifyCaptchaInvalid() { + cases := []struct { + desc string + captchaConf *conf.CaptchaConfiguration + expectedCode int + expectedMsg string + }{ + { + "Captcha validation failed", + &conf.CaptchaConfiguration{ + Enabled: true, + Provider: "hcaptcha", + Secret: "test", + }, + http.StatusBadRequest, + "captcha protection: request disallowed (not-using-dummy-secret)", + }, + { + "Captcha validation failed", + &conf.CaptchaConfiguration{ + Enabled: true, + Provider: "turnstile", + Secret: "anothertest", + }, + http.StatusBadRequest, + "captcha protection: request disallowed (invalid-input-secret)", + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + ts.Config.Security.Captcha = *c.captchaConf + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "password": "secret", + "gotrue_meta_security": map[string]interface{}{ + 
"captcha_token": CaptchaResponse, + }, + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer) + req.Header.Set("Content-Type", "application/json") + + req = req.WithContext(context.Background()) + + w := httptest.NewRecorder() + + _, err := ts.API.verifyCaptcha(w, req) + require.Equal(ts.T(), c.expectedCode, err.(*HTTPError).HTTPStatus) + require.Equal(ts.T(), c.expectedMsg, err.(*HTTPError).Message) + }) + } +} + +func (ts *MiddlewareTestSuite) TestIsValidExternalHost() { + cases := []struct { + desc string + externalHosts []string + + requestURL string + headers http.Header + + expectedURL string + }{ + { + desc: "no defined external hosts, no headers, no absolute request URL", + requestURL: "/some-path", + expectedURL: ts.API.config.API.ExternalURL, + }, + + { + desc: "no defined external hosts, unauthorized X-Forwarded-Host without any external hosts", + headers: http.Header{ + "X-Forwarded-Host": []string{ + "external-host.com", + }, + }, + requestURL: "/some-path", + expectedURL: ts.API.config.API.ExternalURL, + }, + + { + desc: "defined external hosts, unauthorized X-Forwarded-Host", + externalHosts: []string{"authorized-host.com"}, + headers: http.Header{ + "X-Forwarded-Proto": []string{"https"}, + "X-Forwarded-Host": []string{ + "external-host.com", + }, + }, + requestURL: "/some-path", + expectedURL: ts.API.config.API.ExternalURL, + }, + + { + desc: "no defined external hosts, unauthorized Host", + requestURL: "https://external-host.com/some-path", + expectedURL: ts.API.config.API.ExternalURL, + }, + + { + desc: "defined external hosts, unauthorized Host", + externalHosts: []string{"authorized-host.com"}, + requestURL: "https://external-host.com/some-path", + expectedURL: ts.API.config.API.ExternalURL, + }, + + { + desc: "defined external hosts, authorized X-Forwarded-Host", + externalHosts: []string{"authorized-host.com"}, + headers: http.Header{ + "X-Forwarded-Proto": []string{"http"}, // this should be ignored and default to HTTPS + "X-Forwarded-Host": []string{ + "authorized-host.com", + }, + }, + requestURL: "https://X-Forwarded-Host-takes-precedence.com/some-path", + expectedURL: "https://authorized-host.com", + }, + + { + desc: "defined external hosts, authorized Host", + externalHosts: []string{"authorized-host.com"}, + requestURL: "https://authorized-host.com/some-path", + expectedURL: "https://authorized-host.com", + }, + + { + desc: "defined external hosts, authorized X-Forwarded-Host", + externalHosts: []string{"authorized-host.com"}, + headers: http.Header{ + "X-Forwarded-Proto": []string{"http"}, // this should be ignored and default to HTTPS + "X-Forwarded-Host": []string{ + "authorized-host.com", + }, + }, + requestURL: "https://X-Forwarded-Host-takes-precedence.com/some-path", + expectedURL: "https://authorized-host.com", + }, + + { + desc: "defined external hosts, authorized localhost in X-Forwarded-Host with HTTP", + externalHosts: []string{"localhost"}, + headers: http.Header{ + "X-Forwarded-Proto": []string{"http"}, + "X-Forwarded-Host": []string{ + "localhost", + }, + }, + requestURL: "/some-path", + expectedURL: "http://localhost", + }, + + { + desc: "defined external hosts, authorized localhost in Host with HTTP", + externalHosts: []string{"localhost"}, + requestURL: "http://localhost:3000/some-path", + expectedURL: "http://localhost", + }, + } + + require.NotEmpty(ts.T(), ts.API.config.API.ExternalURL) + + for _, c := range cases { + ts.Run(c.desc, func() { + req := httptest.NewRequest(http.MethodPost, c.requestURL, nil) + if 
c.headers != nil { + req.Header = c.headers + } + + originalHosts := ts.API.config.Mailer.ExternalHosts + ts.API.config.Mailer.ExternalHosts = c.externalHosts + + w := httptest.NewRecorder() + ctx, err := ts.API.isValidExternalHost(w, req) + + ts.API.config.Mailer.ExternalHosts = originalHosts + + require.NoError(ts.T(), err) + + externalURL := getExternalHost(ctx) + require.Equal(ts.T(), c.expectedURL, externalURL.String()) + }) + } +} + +func (ts *MiddlewareTestSuite) TestRequireSAMLEnabled() { + cases := []struct { + desc string + isEnabled bool + expectedErr error + }{ + { + desc: "SAML not enabled", + isEnabled: false, + expectedErr: notFoundError(ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled"), + }, + { + desc: "SAML enabled", + isEnabled: true, + expectedErr: nil, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + ts.Config.SAML.Enabled = c.isEnabled + req := httptest.NewRequest("GET", "http://localhost", nil) + w := httptest.NewRecorder() + + _, err := ts.API.requireSAMLEnabled(w, req) + require.Equal(ts.T(), c.expectedErr, err) + }) + } +} + +func TestFunctionHooksUnmarshalJSON(t *testing.T) { + tests := []struct { + in string + ok bool + }{ + {`{ "signup" : "identity-signup" }`, true}, + {`{ "signup" : ["identity-signup"] }`, true}, + {`{ "signup" : {"foo" : "bar"} }`, false}, + } + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + var f FunctionHooks + err := json.Unmarshal([]byte(tt.in), &f) + if tt.ok { + assert.NoError(t, err) + assert.Equal(t, FunctionHooks{"signup": {"identity-signup"}}, f) + } else { + assert.Error(t, err) + } + }) + } +} + +func (ts *MiddlewareTestSuite) TestTimeoutMiddleware() { + ts.Config.API.MaxRequestDuration = 5 * time.Microsecond + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + w := httptest.NewRecorder() + + timeoutHandler := timeoutMiddleware(ts.Config.API.MaxRequestDuration) + + slowHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Sleep for 1 second to simulate a slow handler which should trigger the timeout + time.Sleep(1 * time.Second) + ts.API.handler.ServeHTTP(w, r) + }) + timeoutHandler(slowHandler).ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusGatewayTimeout, w.Code) + + var data map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), ErrorCodeRequestTimeout, data["error_code"]) + require.Equal(ts.T(), float64(504), data["code"]) + require.NotNil(ts.T(), data["msg"]) +} + +func TestTimeoutResponseWriter(t *testing.T) { + // timeoutResponseWriter should exhitbit a similar behavior as http.ResponseWriter + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + w1 := httptest.NewRecorder() + w2 := httptest.NewRecorder() + + timeoutHandler := timeoutMiddleware(time.Second * 10) + + redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // tries to redirect twice + http.Redirect(w, r, "http://localhost:3001/#message=first_message", http.StatusSeeOther) + + // overwrites the first + http.Redirect(w, r, "http://localhost:3001/second", http.StatusSeeOther) + }) + timeoutHandler(redirectHandler).ServeHTTP(w1, req) + redirectHandler.ServeHTTP(w2, req) + + require.Equal(t, w1.Result(), w2.Result()) +} + +func (ts *MiddlewareTestSuite) TestLimitHandler() { + ts.Config.RateLimitHeader = "X-Rate-Limit" + lmt := tollbooth.NewLimiter(5, &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }) + + okHandler := http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + b, _ := json.Marshal(map[string]interface{}{"message": "ok"}) + w.Write([]byte(b)) + }) + + for i := 0; i < 5; i++ { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") + w := httptest.NewRecorder() + ts.API.limitHandler(lmt).handler(okHandler).ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + var data map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), "ok", data["message"]) + } + + // 6th request should fail and return a rate limit exceeded error + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0") + w := httptest.NewRecorder() + ts.API.limitHandler(lmt).handler(okHandler).ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusTooManyRequests, w.Code) +} + +type MockCleanup struct { + mock.Mock +} + +func (m *MockCleanup) Clean(db *storage.Connection) (int, error) { + m.Called(db) + return 0, nil +} + +func (ts *MiddlewareTestSuite) TestDatabaseCleanup() { + testHandler := func(statusCode int) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(statusCode) + b, _ := json.Marshal(map[string]interface{}{"message": "ok"}) + w.Write([]byte(b)) + }) + } + + cases := []struct { + desc string + statusCode int + method string + }{ + { + desc: "Run cleanup successfully", + statusCode: http.StatusOK, + method: http.MethodPost, + }, + { + desc: "Skip cleanup if GET", + statusCode: http.StatusOK, + method: http.MethodGet, + }, + { + desc: "Skip cleanup if 3xx", + statusCode: http.StatusSeeOther, + method: http.MethodPost, + }, + { + desc: "Skip cleanup if 4xx", + statusCode: http.StatusBadRequest, + method: http.MethodPost, + }, + { + desc: "Skip cleanup if 5xx", + statusCode: http.StatusInternalServerError, + method: http.MethodPost, + }, + } + + mockCleanup := new(MockCleanup) + mockCleanup.On("Clean", mock.Anything).Return(0, nil) + for _, c := range cases { + ts.Run("DatabaseCleanup", func() { + req := httptest.NewRequest(c.method, "http://localhost", nil) + w := httptest.NewRecorder() + ts.API.databaseCleanup(mockCleanup)(testHandler(c.statusCode)).ServeHTTP(w, req) + require.Equal(ts.T(), c.statusCode, w.Code) + }) + } + mockCleanup.AssertNumberOfCalls(ts.T(), "Clean", 1) +} diff --git a/auth_v2.169.0/internal/api/opentelemetry-tracer_test.go b/auth_v2.169.0/internal/api/opentelemetry-tracer_test.go new file mode 100644 index 0000000..4aeddce --- /dev/null +++ b/auth_v2.169.0/internal/api/opentelemetry-tracer_test.go @@ -0,0 +1,93 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" + semconv "go.opentelemetry.io/otel/semconv/v1.25.0" +) + +type OpenTelemetryTracerTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestOpenTelemetryTracer(t *testing.T) { + api, config, err := setupAPIForTestWithCallback(func(config *conf.GlobalConfiguration, conn *storage.Connection) { + if config != nil { + config.Tracing.Enabled = true + 
config.Tracing.Exporter = conf.OpenTelemetryTracing + } + }) + + require.NoError(t, err) + + ts := &OpenTelemetryTracerTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func getAttribute(attributes []attribute.KeyValue, key attribute.Key) *attribute.Value { + for _, value := range attributes { + if value.Key == key { + return &value.Value + } + } + + return nil +} + +func (ts *OpenTelemetryTracerTestSuite) TestOpenTelemetryTracer_Spans() { + exporter := tracetest.NewInMemoryExporter() + bsp := sdktrace.NewSimpleSpanProcessor(exporter) + traceProvider := sdktrace.NewTracerProvider( + sdktrace.WithSampler(sdktrace.AlwaysSample()), + sdktrace.WithSpanProcessor(bsp), + ) + otel.SetTracerProvider(traceProvider) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "http://localhost/something1", nil) + req.Header.Set("User-Agent", "whatever") + ts.API.handler.ServeHTTP(w, req) + + req = httptest.NewRequest(http.MethodGet, "http://localhost/something2", nil) + req.Header.Set("User-Agent", "whatever") + ts.API.handler.ServeHTTP(w, req) + + spanStubs := exporter.GetSpans() + spans := spanStubs.Snapshots() + + if assert.Equal(ts.T(), 2, len(spans)) { + attributes1 := spans[0].Attributes() + method1 := getAttribute(attributes1, semconv.HTTPMethodKey) + assert.Equal(ts.T(), "POST", method1.AsString()) + url1 := getAttribute(attributes1, semconv.HTTPTargetKey) + assert.Equal(ts.T(), "/something1", url1.AsString()) + statusCode1 := getAttribute(attributes1, semconv.HTTPStatusCodeKey) + assert.Equal(ts.T(), int64(404), statusCode1.AsInt64()) + + attributes2 := spans[1].Attributes() + method2 := getAttribute(attributes2, semconv.HTTPMethodKey) + assert.Equal(ts.T(), "GET", method2.AsString()) + url2 := getAttribute(attributes2, semconv.HTTPTargetKey) + assert.Equal(ts.T(), "/something2", url2.AsString()) + statusCode2 := getAttribute(attributes2, semconv.HTTPStatusCodeKey) + assert.Equal(ts.T(), int64(404), statusCode2.AsInt64()) + } +} diff --git a/auth_v2.169.0/internal/api/options.go b/auth_v2.169.0/internal/api/options.go new file mode 100644 index 0000000..9053c2f --- /dev/null +++ b/auth_v2.169.0/internal/api/options.go @@ -0,0 +1,102 @@ +package api + +import ( + "time" + + "github.com/didip/tollbooth/v5" + "github.com/didip/tollbooth/v5/limiter" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/ratelimit" +) + +type Option interface { + apply(*API) +} + +type LimiterOptions struct { + Email ratelimit.Limiter + Phone ratelimit.Limiter + + Signups *limiter.Limiter + AnonymousSignIns *limiter.Limiter + Recover *limiter.Limiter + Resend *limiter.Limiter + MagicLink *limiter.Limiter + Otp *limiter.Limiter + Token *limiter.Limiter + Verify *limiter.Limiter + User *limiter.Limiter + FactorVerify *limiter.Limiter + FactorChallenge *limiter.Limiter + SSO *limiter.Limiter + SAMLAssertion *limiter.Limiter +} + +func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo } + +func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions { + o := &LimiterOptions{} + + o.Email = ratelimit.New(gc.RateLimitEmailSent) + o.Phone = ratelimit.New(gc.RateLimitSmsSent) + + o.AnonymousSignIns = tollbooth.NewLimiter(gc.RateLimitAnonymousUsers/(60*60), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(int(gc.RateLimitAnonymousUsers)).SetMethods([]string{"POST"}) + + o.Token = tollbooth.NewLimiter(gc.RateLimitTokenRefresh/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: 
time.Hour, + }).SetBurst(30) + + o.Verify = tollbooth.NewLimiter(gc.RateLimitVerify/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + o.User = tollbooth.NewLimiter(gc.RateLimitOtp/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + o.FactorVerify = tollbooth.NewLimiter(gc.MFA.RateLimitChallengeAndVerify/60, + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Minute, + }).SetBurst(30) + + o.FactorChallenge = tollbooth.NewLimiter(gc.MFA.RateLimitChallengeAndVerify/60, + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Minute, + }).SetBurst(30) + + o.SSO = tollbooth.NewLimiter(gc.RateLimitSso/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + o.SAMLAssertion = tollbooth.NewLimiter(gc.SAML.RateLimitAssertion/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + o.Signups = tollbooth.NewLimiter(gc.RateLimitOtp/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + // These all use the OTP limit per 5 min with 1hour ttl and burst of 30. + o.Recover = newLimiterPer5mOver1h(gc.RateLimitOtp) + o.Resend = newLimiterPer5mOver1h(gc.RateLimitOtp) + o.MagicLink = newLimiterPer5mOver1h(gc.RateLimitOtp) + o.Otp = newLimiterPer5mOver1h(gc.RateLimitOtp) + return o +} + +func newLimiterPer5mOver1h(rate float64) *limiter.Limiter { + freq := rate / (60 * 5) + lim := tollbooth.NewLimiter(freq, &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + return lim +} diff --git a/auth_v2.169.0/internal/api/options_test.go b/auth_v2.169.0/internal/api/options_test.go new file mode 100644 index 0000000..c4c1d16 --- /dev/null +++ b/auth_v2.169.0/internal/api/options_test.go @@ -0,0 +1,30 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/supabase/auth/internal/conf" +) + +func TestNewLimiterOptions(t *testing.T) { + cfg := &conf.GlobalConfiguration{} + cfg.ApplyDefaults() + + rl := NewLimiterOptions(cfg) + assert.NotNil(t, rl.Email) + assert.NotNil(t, rl.Phone) + assert.NotNil(t, rl.Signups) + assert.NotNil(t, rl.AnonymousSignIns) + assert.NotNil(t, rl.Recover) + assert.NotNil(t, rl.Resend) + assert.NotNil(t, rl.MagicLink) + assert.NotNil(t, rl.Otp) + assert.NotNil(t, rl.Token) + assert.NotNil(t, rl.Verify) + assert.NotNil(t, rl.User) + assert.NotNil(t, rl.FactorVerify) + assert.NotNil(t, rl.FactorChallenge) + assert.NotNil(t, rl.SSO) + assert.NotNil(t, rl.SAMLAssertion) +} diff --git a/auth_v2.169.0/internal/api/otp.go b/auth_v2.169.0/internal/api/otp.go new file mode 100644 index 0000000..1821da3 --- /dev/null +++ b/auth_v2.169.0/internal/api/otp.go @@ -0,0 +1,237 @@ +package api + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + + "github.com/sethvargo/go-password/password" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// OtpParams contains the request body params for the otp endpoint +type OtpParams struct { + Email string `json:"email"` + Phone string `json:"phone"` + CreateUser bool `json:"create_user"` + Data map[string]interface{} `json:"data"` + Channel string `json:"channel"` + CodeChallengeMethod string `json:"code_challenge_method"` + CodeChallenge string `json:"code_challenge"` +} + +// SmsParams contains the request body params for sms otp +type SmsParams struct 
{ + Phone string `json:"phone"` + Channel string `json:"channel"` + Data map[string]interface{} `json:"data"` + CodeChallengeMethod string `json:"code_challenge_method"` + CodeChallenge string `json:"code_challenge"` +} + +func (p *OtpParams) Validate() error { + if p.Email != "" && p.Phone != "" { + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided") + } + if p.Email != "" && p.Channel != "" { + return badRequestError(ErrorCodeValidationFailed, "Channel should only be specified with Phone OTP") + } + if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { + return err + } + return nil +} + +func (p *SmsParams) Validate(config *conf.GlobalConfiguration) error { + var err error + p.Phone, err = validatePhone(p.Phone) + if err != nil { + return err + } + if !sms_provider.IsValidMessageChannel(p.Channel, config) { + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) + } + return nil +} + +// Otp returns the MagicLink or SmsOtp handler based on the request body params +func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { + params := &OtpParams{ + CreateUser: true, + } + if params.Data == nil { + params.Data = make(map[string]interface{}) + } + + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if err := params.Validate(); err != nil { + return err + } + if params.Data == nil { + params.Data = make(map[string]interface{}) + } + + if ok, err := a.shouldCreateUser(r, params); !ok { + return unprocessableEntityError(ErrorCodeOTPDisabled, "Signups not allowed for otp") + } else if err != nil { + return err + } + + if params.Email != "" { + return a.MagicLink(w, r) + } else if params.Phone != "" { + return a.SmsOtp(w, r) + } + + return badRequestError(ErrorCodeValidationFailed, "One of email or phone must be set") +} + +type SmsOtpResponse struct { + MessageID string `json:"message_id,omitempty"` +} + +// SmsOtp sends the user an otp via sms +func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + + if !config.External.Phone.Enabled { + return badRequestError(ErrorCodePhoneProviderDisabled, "Unsupported phone provider") + } + var err error + + params := &SmsParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + // For backwards compatibility, we default to SMS if params Channel is not specified + if params.Phone != "" && params.Channel == "" { + params.Channel = sms_provider.SMSProvider + } + + if err := params.Validate(config); err != nil { + return err + } + + var isNewUser bool + aud := a.requestAud(ctx, r) + user, err := models.FindUserByPhoneAndAudience(db, params.Phone, aud) + if err != nil { + if models.IsNotFoundError(err) { + isNewUser = true + } else { + return internalServerError("Database error finding user").WithInternalError(err) + } + } + if user != nil { + isNewUser = !user.IsPhoneConfirmed() + } + if isNewUser { + // User either doesn't exist or hasn't completed the signup process. + // Sign them up with temporary password. 
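+		// password.Generate comes from github.com/sethvargo/go-password; its
+		// arguments are (length, numDigits, numSymbols, noUpper, allowRepeat),
+		// so this yields a random 64-character throwaway password containing
+		// 10 digits and 1 symbol.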
+ password, err := password.Generate(64, 10, 1, false, true) + if err != nil { + return internalServerError("error creating user").WithInternalError(err) + } + + signUpParams := &SignupParams{ + Phone: params.Phone, + Password: password, + Data: params.Data, + Channel: params.Channel, + } + newBodyContent, err := json.Marshal(signUpParams) + if err != nil { + // SignupParams must be marshallable + panic(err) + } + r.Body = io.NopCloser(bytes.NewReader(newBodyContent)) + + fakeResponse := &responseStub{} + + if config.Sms.Autoconfirm { + // signups are autoconfirmed, send otp after signup + if err := a.Signup(fakeResponse, r); err != nil { + return err + } + + signUpParams := &SignupParams{ + Phone: params.Phone, + Channel: params.Channel, + } + newBodyContent, err := json.Marshal(signUpParams) + if err != nil { + // SignupParams must be marshallable + panic(err) + } + r.Body = io.NopCloser(bytes.NewReader(newBodyContent)) + return a.SmsOtp(w, r) + } + + if err := a.Signup(fakeResponse, r); err != nil { + return err + } + return sendJSON(w, http.StatusOK, make(map[string]string)) + } + + messageID := "" + err = db.Transaction(func(tx *storage.Connection) error { + if err := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", map[string]interface{}{ + "channel": params.Channel, + }); err != nil { + return err + } + mID, serr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, params.Channel) + if serr != nil { + return serr + } + messageID = mID + return nil + }) + + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, SmsOtpResponse{ + MessageID: messageID, + }) +} + +func (a *API) shouldCreateUser(r *http.Request, params *OtpParams) (bool, error) { + ctx := r.Context() + db := a.db.WithContext(ctx) + + if !params.CreateUser { + ctx := r.Context() + aud := a.requestAud(ctx, r) + var err error + if params.Email != "" { + params.Email, err = a.validateEmail(params.Email) + if err != nil { + return false, err + } + _, err = models.FindUserByEmailAndAudience(db, params.Email, aud) + } else if params.Phone != "" { + params.Phone, err = validatePhone(params.Phone) + if err != nil { + return false, err + } + _, err = models.FindUserByPhoneAndAudience(db, params.Phone, aud) + } + + if err != nil && models.IsNotFoundError(err) { + return false, nil + } + } + return true, nil +} diff --git a/auth_v2.169.0/internal/api/otp_test.go b/auth_v2.169.0/internal/api/otp_test.go new file mode 100644 index 0000000..c72fbc3 --- /dev/null +++ b/auth_v2.169.0/internal/api/otp_test.go @@ -0,0 +1,311 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type OtpTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestOtp(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &OtpTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *OtpTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + +} + +func (ts *OtpTestSuite) TestOtpPKCE() { + ts.Config.External.Phone.Enabled = true + testCodeChallenge := "testtesttesttesttesttesttestteststeststesttesttesttest" + + var buffer bytes.Buffer + cases := []struct { + desc string + params OtpParams + expected struct { + code int 
+ response map[string]interface{} + } + }{ + { + desc: "Test (PKCE) Success Magiclink Otp", + params: OtpParams{ + Email: "test@example.com", + CreateUser: true, + CodeChallengeMethod: "s256", + CodeChallenge: testCodeChallenge, + }, + expected: struct { + code int + response map[string]interface{} + }{ + http.StatusOK, + make(map[string]interface{}), + }, + }, + { + desc: "Test (PKCE) Failure, no code challenge", + params: OtpParams{ + Email: "test@example.com", + CreateUser: true, + CodeChallengeMethod: "s256", + }, + expected: struct { + code int + response map[string]interface{} + }{ + http.StatusBadRequest, + map[string]interface{}{ + "code": float64(http.StatusBadRequest), + "error_code": ErrorCodeValidationFailed, + "msg": "PKCE flow requires code_challenge_method and code_challenge", + }, + }, + }, + { + desc: "Test (PKCE) Failure, no code challenge method", + params: OtpParams{ + Email: "test@example.com", + CreateUser: true, + CodeChallenge: testCodeChallenge, + }, + expected: struct { + code int + response map[string]interface{} + }{ + http.StatusBadRequest, + map[string]interface{}{ + "code": float64(http.StatusBadRequest), + "error_code": ErrorCodeValidationFailed, + "msg": "PKCE flow requires code_challenge_method and code_challenge", + }, + }, + }, + { + desc: "Test (PKCE) Success, phone with valid params", + params: OtpParams{ + Phone: "123456789", + CreateUser: true, + CodeChallengeMethod: "s256", + CodeChallenge: testCodeChallenge, + }, + expected: struct { + code int + response map[string]interface{} + }{ + http.StatusInternalServerError, + map[string]interface{}{ + "code": float64(http.StatusInternalServerError), + "msg": "Unable to get SMS provider", + }, + }, + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) + + req := httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), c.expected.code, w.Code) + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + }) + } +} + +func (ts *OtpTestSuite) TestOtp() { + // Configured to allow testing of invalid channel params + ts.Config.External.Phone.Enabled = true + cases := []struct { + desc string + params OtpParams + expected struct { + code int + response map[string]interface{} + } + }{ + { + desc: "Test Success Magiclink Otp", + params: OtpParams{ + Email: "test@example.com", + CreateUser: true, + Data: map[string]interface{}{ + "somedata": "metadata", + }, + }, + expected: struct { + code int + response map[string]interface{} + }{ + http.StatusOK, + make(map[string]interface{}), + }, + }, + { + desc: "Test Failure Pass Both Email & Phone", + params: OtpParams{ + Email: "test@example.com", + Phone: "123456789", + CreateUser: true, + }, + expected: struct { + code int + response map[string]interface{} + }{ + http.StatusBadRequest, + map[string]interface{}{ + "code": float64(http.StatusBadRequest), + "error_code": ErrorCodeValidationFailed, + "msg": "Only an email address or phone number should be provided", + }, + }, + }, + { + desc: "Test Failure invalid channel param", + params: OtpParams{ + Phone: "123456789", + Channel: "invalidchannel", + CreateUser: true, + }, + expected: struct { + code int + response map[string]interface{} + }{ + http.StatusBadRequest, + map[string]interface{}{ + "code": float64(http.StatusBadRequest), + "error_code": 
ErrorCodeValidationFailed, + "msg": InvalidChannelError, + }, + }, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) + + req := httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), c.expected.code, w.Code) + + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + // response should be empty + assert.Equal(ts.T(), data, c.expected.response) + }) + } +} + +func (ts *OtpTestSuite) TestNoSignupsForOtp() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "newuser@example.com", + "create_user": false, + })) + + req := httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusUnprocessableEntity, w.Code) + + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + // response should be empty + assert.Equal(ts.T(), data, map[string]interface{}{ + "code": float64(http.StatusUnprocessableEntity), + "error_code": ErrorCodeOTPDisabled, + "msg": "Signups not allowed for otp", + }) +} + +func (ts *OtpTestSuite) TestSubsequentOtp() { + ts.Config.SMTP.MaxFrequency = 0 + userEmail := "foo@example.com" + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": userEmail, + })) + + req := httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + newUser, err := models.FindUserByEmailAndAudience(ts.API.db, userEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), newUser.ConfirmationToken) + require.NotEmpty(ts.T(), newUser.ConfirmationSentAt) + require.Empty(ts.T(), newUser.RecoveryToken) + require.Empty(ts.T(), newUser.RecoverySentAt) + require.Empty(ts.T(), newUser.EmailConfirmedAt) + + // since the signup process hasn't been completed, + // subsequent requests for another magiclink should not create a recovery token + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": userEmail, + })) + + req = httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + user, err := models.FindUserByEmailAndAudience(ts.API.db, userEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), user.ConfirmationToken) + require.NotEmpty(ts.T(), user.ConfirmationSentAt) + require.Empty(ts.T(), user.RecoveryToken) + require.Empty(ts.T(), user.RecoverySentAt) + require.Empty(ts.T(), user.EmailConfirmedAt) +} diff --git a/auth_v2.169.0/internal/api/pagination.go b/auth_v2.169.0/internal/api/pagination.go new file mode 100644 index 0000000..386f403 --- /dev/null +++ b/auth_v2.169.0/internal/api/pagination.go @@ -0,0 +1,64 @@ +package api + +import ( + "fmt" + "net/http" + "net/url" + "strconv" + + "github.com/supabase/auth/internal/models" +) + +const defaultPerPage = 50 + 
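+// calculateTotalPages returns the number of pages needed to list `total`
+// items at `perPage` items per page, i.e. the integer ceiling of
+// total/perPage (e.g. total=101, perPage=50 gives 3 pages).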
+func calculateTotalPages(perPage, total uint64) uint64 { + pages := total / perPage + if total%perPage > 0 { + return pages + 1 + } + return pages +} + +func addPaginationHeaders(w http.ResponseWriter, r *http.Request, p *models.Pagination) { + totalPages := calculateTotalPages(p.PerPage, p.Count) + url, _ := url.ParseRequestURI(r.URL.String()) + query := url.Query() + header := "" + if totalPages > p.Page { + query.Set("page", fmt.Sprintf("%v", p.Page+1)) + url.RawQuery = query.Encode() + header += "<" + url.String() + ">; rel=\"next\", " + } + query.Set("page", fmt.Sprintf("%v", totalPages)) + url.RawQuery = query.Encode() + header += "<" + url.String() + ">; rel=\"last\"" + + w.Header().Add("Link", header) + w.Header().Add("X-Total-Count", fmt.Sprintf("%v", p.Count)) +} + +func paginate(r *http.Request) (*models.Pagination, error) { + params := r.URL.Query() + queryPage := params.Get("page") + queryPerPage := params.Get("per_page") + var page uint64 = 1 + var perPage uint64 = defaultPerPage + var err error + if queryPage != "" { + page, err = strconv.ParseUint(queryPage, 10, 64) + if err != nil { + return nil, err + } + } + if queryPerPage != "" { + perPage, err = strconv.ParseUint(queryPerPage, 10, 64) + if err != nil { + return nil, err + } + } + + return &models.Pagination{ + Page: page, + PerPage: perPage, + }, nil +} diff --git a/auth_v2.169.0/internal/api/password.go b/auth_v2.169.0/internal/api/password.go new file mode 100644 index 0000000..73de368 --- /dev/null +++ b/auth_v2.169.0/internal/api/password.go @@ -0,0 +1,73 @@ +package api + +import ( + "context" + "fmt" + "strings" + + "github.com/sirupsen/logrus" +) + +// BCrypt hashed passwords have a 72 character limit +const MaxPasswordLength = 72 + +// WeakPasswordError encodes an error that a password does not meet strength +// requirements. It is handled specially in errors.go as it gets transformed to +// a HTTPError with a special weak_password field that encodes the Reasons +// slice. 
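+// The Reasons values currently produced by checkPasswordStrength are
+// "length", "characters" and "pwned".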
+type WeakPasswordError struct { + Message string `json:"message,omitempty"` + Reasons []string `json:"reasons,omitempty"` +} + +func (e *WeakPasswordError) Error() string { + return e.Message +} + +func (a *API) checkPasswordStrength(ctx context.Context, password string) error { + config := a.config + + if len(password) > MaxPasswordLength { + return badRequestError(ErrorCodeValidationFailed, fmt.Sprintf("Password cannot be longer than %v characters", MaxPasswordLength)) + } + + var messages, reasons []string + + if len(password) < config.Password.MinLength { + reasons = append(reasons, "length") + messages = append(messages, fmt.Sprintf("Password should be at least %d characters.", config.Password.MinLength)) + } + + for _, characterSet := range config.Password.RequiredCharacters { + if characterSet != "" && !strings.ContainsAny(password, characterSet) { + reasons = append(reasons, "characters") + + messages = append(messages, fmt.Sprintf("Password should contain at least one character of each: %s.", strings.Join(config.Password.RequiredCharacters, ", "))) + + break + } + } + + if config.Password.HIBP.Enabled { + pwned, err := a.hibpClient.Check(ctx, password) + if err != nil { + if config.Password.HIBP.FailClosed { + return internalServerError("Unable to perform password strength check with HaveIBeenPwned.org.").WithInternalError(err) + } else { + logrus.WithError(err).Warn("Unable to perform password strength check with HaveIBeenPwned.org, pwned passwords are being allowed") + } + } else if pwned { + reasons = append(reasons, "pwned") + messages = append(messages, "Password is known to be weak and easy to guess, please choose a different one.") + } + } + + if len(reasons) > 0 { + return &WeakPasswordError{ + Message: strings.Join(messages, " "), + Reasons: reasons, + } + } + + return nil +} diff --git a/auth_v2.169.0/internal/api/password_test.go b/auth_v2.169.0/internal/api/password_test.go new file mode 100644 index 0000000..f95f6f6 --- /dev/null +++ b/auth_v2.169.0/internal/api/password_test.go @@ -0,0 +1,117 @@ +package api + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +func TestPasswordStrengthChecks(t *testing.T) { + examples := []struct { + MinLength int + RequiredCharacters []string + + Password string + Reasons []string + }{ + { + MinLength: 6, + Password: "12345", + Reasons: []string{ + "length", + }, + }, + { + MinLength: 6, + RequiredCharacters: []string{ + "a", + "b", + "c", + }, + Password: "123", + Reasons: []string{ + "length", + "characters", + }, + }, + { + MinLength: 6, + RequiredCharacters: []string{ + "a", + "b", + "c", + }, + Password: "a123", + Reasons: []string{ + "length", + "characters", + }, + }, + { + MinLength: 6, + RequiredCharacters: []string{ + "a", + "b", + "c", + }, + Password: "ab123", + Reasons: []string{ + "length", + "characters", + }, + }, + { + MinLength: 6, + RequiredCharacters: []string{ + "a", + "b", + "c", + }, + Password: "c123", + Reasons: []string{ + "length", + "characters", + }, + }, + { + MinLength: 6, + RequiredCharacters: []string{ + "a", + "b", + "c", + }, + Password: "abc123", + Reasons: nil, + }, + { + MinLength: 6, + RequiredCharacters: []string{}, + Password: "zZgXb5gzyCNrV36qwbOSbKVQsVJd28mC1TwRpeB0y6sFNICJyjD6bILKJMsjyKDzBdaY5tmi8zY9BWJYmt3vULLmyafjIDLYjy8qhETu0mS2jj1uQBgSAzJn9Zjm8EFa", + Reasons: nil, + }, + } + + for i, example := range examples { + api := &API{ + config: &conf.GlobalConfiguration{ + Password: conf.PasswordConfiguration{ + 
MinLength: example.MinLength, + RequiredCharacters: conf.PasswordRequiredCharacters(example.RequiredCharacters), + }, + }, + } + + err := api.checkPasswordStrength(context.Background(), example.Password) + + switch e := err.(type) { + case *WeakPasswordError: + require.Equal(t, e.Reasons, example.Reasons, "Example %d failed with wrong reasons", i) + case *HTTPError: + require.Equal(t, e.ErrorCode, ErrorCodeValidationFailed, "Example %d failed with wrong error code", i) + default: + require.NoError(t, err, "Example %d failed with error", i) + } + } +} diff --git a/auth_v2.169.0/internal/api/phone.go b/auth_v2.169.0/internal/api/phone.go new file mode 100644 index 0000000..5033888 --- /dev/null +++ b/auth_v2.169.0/internal/api/phone.go @@ -0,0 +1,169 @@ +package api + +import ( + "bytes" + "net/http" + "regexp" + "strings" + "text/template" + "time" + + "github.com/supabase/auth/internal/hooks" + + "github.com/pkg/errors" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +var e164Format = regexp.MustCompile("^[1-9][0-9]{1,14}$") + +const ( + phoneConfirmationOtp = "confirmation" + phoneReauthenticationOtp = "reauthentication" +) + +func validatePhone(phone string) (string, error) { + phone = formatPhoneNumber(phone) + if isValid := validateE164Format(phone); !isValid { + return "", badRequestError(ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") + } + return phone, nil +} + +// validateE164Format checks if phone number follows the E.164 format +func validateE164Format(phone string) bool { + return e164Format.MatchString(phone) +} + +// formatPhoneNumber removes "+" and whitespaces in a phone number +func formatPhoneNumber(phone string) string { + return strings.ReplaceAll(strings.TrimPrefix(phone, "+"), " ", "") +} + +// sendPhoneConfirmation sends an otp to the user's phone number +func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, channel string) (string, error) { + config := a.config + + var token *string + var sentAt *time.Time + + includeFields := []string{} + switch otpType { + case phoneChangeVerification: + token = &user.PhoneChangeToken + sentAt = user.PhoneChangeSentAt + user.PhoneChange = phone + includeFields = append(includeFields, "phone_change", "phone_change_token", "phone_change_sent_at") + case phoneConfirmationOtp: + token = &user.ConfirmationToken + sentAt = user.ConfirmationSentAt + includeFields = append(includeFields, "confirmation_token", "confirmation_sent_at") + case phoneReauthenticationOtp: + token = &user.ReauthenticationToken + sentAt = user.ReauthenticationSentAt + includeFields = append(includeFields, "reauthentication_token", "reauthentication_sent_at") + default: + return "", internalServerError("invalid otp type") + } + + // intentionally keeping this before the test OTP, so that the behavior + // of regular and test OTPs is similar + if sentAt != nil && !sentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) { + return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, generateFrequencyLimitErrorMessage(sentAt, config.Sms.MaxFrequency)) + } + + now := time.Now() + + var otp, messageID string + + if testOTP, ok := config.Sms.GetTestOTP(phone, now); ok { + otp = testOTP + messageID = "test-otp" + } + + // not using test OTPs + if otp == "" { + // TODO(km): Deprecate this behaviour - rate limits should still be applied 
to autoconfirm + if !config.Sms.Autoconfirm { + // apply rate limiting before the sms is sent out + if ok := a.limiterOpts.Phone.Allow(); !ok { + return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") + } + } + otp = crypto.GenerateOtp(config.Sms.OtpLength) + + if config.Hook.SendSMS.Enabled { + input := hooks.SendSMSInput{ + User: user, + SMS: hooks.SMS{ + OTP: otp, + }, + } + output := hooks.SendSMSOutput{} + err := a.invokeHook(tx, r, &input, &output) + if err != nil { + return "", err + } + } else { + smsProvider, err := sms_provider.GetSmsProvider(*config) + if err != nil { + return "", internalServerError("Unable to get SMS provider").WithInternalError(err) + } + message, err := generateSMSFromTemplate(config.Sms.SMSTemplate, otp) + if err != nil { + return "", internalServerError("error generating sms template").WithInternalError(err) + } + messageID, err := smsProvider.SendMessage(phone, message, channel, otp) + if err != nil { + return messageID, unprocessableEntityError(ErrorCodeSMSSendFailed, "Error sending %s OTP to provider: %v", otpType, err) + } + } + } + + *token = crypto.GenerateTokenHash(phone, otp) + + switch otpType { + case phoneConfirmationOtp: + user.ConfirmationSentAt = &now + case phoneChangeVerification: + user.PhoneChangeSentAt = &now + case phoneReauthenticationOtp: + user.ReauthenticationSentAt = &now + } + + if err := tx.UpdateOnly(user, includeFields...); err != nil { + return messageID, errors.Wrap(err, "Database error updating user for phone") + } + + var ottErr error + switch otpType { + case phoneConfirmationOtp: + if err := models.CreateOneTimeToken(tx, user.ID, user.GetPhone(), user.ConfirmationToken, models.ConfirmationToken); err != nil { + ottErr = errors.Wrap(err, "Database error creating confirmation token for phone") + } + case phoneChangeVerification: + if err := models.CreateOneTimeToken(tx, user.ID, user.PhoneChange, user.PhoneChangeToken, models.PhoneChangeToken); err != nil { + ottErr = errors.Wrap(err, "Database error creating phone change token") + } + case phoneReauthenticationOtp: + if err := models.CreateOneTimeToken(tx, user.ID, user.GetPhone(), user.ReauthenticationToken, models.ReauthenticationToken); err != nil { + ottErr = errors.Wrap(err, "Database error creating reauthentication token for phone") + } + } + if ottErr != nil { + return messageID, internalServerError("error creating one time token").WithInternalError(ottErr) + } + return messageID, nil +} + +func generateSMSFromTemplate(SMSTemplate *template.Template, otp string) (string, error) { + var message bytes.Buffer + if err := SMSTemplate.Execute(&message, struct { + Code string + }{Code: otp}); err != nil { + return "", err + } + return message.String(), nil +} diff --git a/auth_v2.169.0/internal/api/phone_test.go b/auth_v2.169.0/internal/api/phone_test.go new file mode 100644 index 0000000..adc50f1 --- /dev/null +++ b/auth_v2.169.0/internal/api/phone_test.go @@ -0,0 +1,443 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type PhoneTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +type TestSmsProvider struct { + mock.Mock + 
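+	// SentMessages counts calls to SendMessage so tests can assert whether an
+	// SMS was actually handed to this mock provider.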
+ SentMessages int +} + +func (t *TestSmsProvider) SendMessage(phone, message, channel, otp string) (string, error) { + t.SentMessages += 1 + return "", nil +} +func (t *TestSmsProvider) VerifyOTP(phone, otp string) error { + return nil +} + +func TestPhone(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &PhoneTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *PhoneTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create user + u, err := models.NewUser("123456789", "", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") +} + +func (ts *PhoneTestSuite) TestValidateE164Format() { + isValid := validateE164Format("0123456789") + assert.Equal(ts.T(), false, isValid) +} + +func (ts *PhoneTestSuite) TestFormatPhoneNumber() { + actual := formatPhoneNumber("+1 23456789 ") + assert.Equal(ts.T(), "123456789", actual) +} + +func doTestSendPhoneConfirmation(ts *PhoneTestSuite, useTestOTP bool) { + u, err := models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil) + require.NoError(ts.T(), err) + cases := []struct { + desc string + otpType string + expected error + }{ + { + desc: "send confirmation otp", + otpType: phoneConfirmationOtp, + expected: nil, + }, + { + desc: "send phone_change otp", + otpType: phoneChangeVerification, + expected: nil, + }, + { + desc: "send recovery otp", + otpType: phoneReauthenticationOtp, + expected: nil, + }, + { + desc: "send invalid otp type ", + otpType: "invalid otp type", + expected: internalServerError("invalid otp type"), + }, + } + + if useTestOTP { + ts.API.config.Sms.TestOTP = map[string]string{ + "123456789": "123456", + } + } else { + ts.API.config.Sms.TestOTP = nil + } + + for _, c := range cases { + ts.Run(c.desc, func() { + provider := &TestSmsProvider{} + sms_provider.MockProvider = provider + + _, err = ts.API.sendPhoneConfirmation(req, ts.API.db, u, "123456789", c.otpType, sms_provider.SMSProvider) + require.Equal(ts.T(), c.expected, err) + u, err = models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + if c.expected == nil { + if useTestOTP { + require.Equal(ts.T(), provider.SentMessages, 0) + } else { + require.Equal(ts.T(), provider.SentMessages, 1) + } + } + + switch c.otpType { + case phoneConfirmationOtp: + require.NotEmpty(ts.T(), u.ConfirmationToken) + require.NotEmpty(ts.T(), u.ConfirmationSentAt) + case phoneChangeVerification: + require.NotEmpty(ts.T(), u.PhoneChangeToken) + require.NotEmpty(ts.T(), u.PhoneChangeSentAt) + case phoneReauthenticationOtp: + require.NotEmpty(ts.T(), u.ReauthenticationToken) + require.NotEmpty(ts.T(), u.ReauthenticationSentAt) + default: + } + }) + } + // Reset at end of test + ts.API.config.Sms.TestOTP = nil + +} + +func (ts *PhoneTestSuite) TestSendPhoneConfirmation() { + doTestSendPhoneConfirmation(ts, false) +} + +func (ts *PhoneTestSuite) TestSendPhoneConfirmationWithTestOTP() { + doTestSendPhoneConfirmation(ts, true) +} + +func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { + u, err := models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + now := time.Now() + u.PhoneConfirmedAt = &now + require.NoError(ts.T(), ts.API.db.Update(u), "Error 
updating new test user") + + s, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(s)) + + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, u, &s.ID, models.PasswordGrant) + require.NoError(ts.T(), err) + + cases := []struct { + desc string + endpoint string + method string + header string + body map[string]string + expected map[string]interface{} + }{ + { + desc: "Signup", + endpoint: "/signup", + method: http.MethodPost, + header: "", + body: map[string]string{ + "phone": "1234567890", + "password": "testpassword", + }, + expected: map[string]interface{}{ + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", + }, + }, + { + desc: "Sms OTP", + endpoint: "/otp", + method: http.MethodPost, + header: "", + body: map[string]string{ + "phone": "123456789", + }, + expected: map[string]interface{}{ + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", + }, + }, + { + desc: "Phone change", + endpoint: "/user", + method: http.MethodPut, + header: token, + body: map[string]string{ + "phone": "111111111", + }, + expected: map[string]interface{}{ + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", + }, + }, + { + desc: "Reauthenticate", + endpoint: "/reauthenticate", + method: http.MethodGet, + header: "", + body: nil, + expected: map[string]interface{}{ + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", + }, + }, + } + + smsProviders := []string{"twilio", "messagebird", "textlocal", "vonage"} + ts.Config.External.Phone.Enabled = true + ts.Config.Sms.Twilio.AccountSid = "" + ts.Config.Sms.Messagebird.AccessKey = "" + ts.Config.Sms.Textlocal.ApiKey = "" + ts.Config.Sms.Vonage.ApiKey = "" + + for _, c := range cases { + for _, provider := range smsProviders { + ts.Config.Sms.Provider = provider + desc := fmt.Sprintf("[%v] %v", provider, c.desc) + ts.Run(desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + req := httptest.NewRequest(c.method, "http://localhost"+c.endpoint, &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expected["code"], w.Code) + + body := w.Body.String() + require.True(ts.T(), + strings.Contains(body, "Unable to get SMS provider") || + strings.Contains(body, "Error finding SMS provider") || + strings.Contains(body, "Failed to get SMS provider"), + "unexpected body message %q", body, + ) + }) + } + } +} +func (ts *PhoneTestSuite) TestSendSMSHook() { + u, err := models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + now := time.Now() + u.PhoneConfirmedAt = &now + require.NoError(ts.T(), ts.API.db.Update(u), "Error updating new test user") + + s, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(s)) + + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, u, &s.ID, models.PasswordGrant) + require.NoError(ts.T(), err) + + // We setup a job table to enqueue SMS requests to send. 
Similar in spirit to the pg_boss postgres extension + createJobsTableSQL := `CREATE TABLE job_queue ( + id serial PRIMARY KEY, + job_type text, + payload jsonb, + status text DEFAULT 'pending', -- Possible values: 'pending', 'processing', 'completed', 'failed' + created_at timestamp without time zone DEFAULT NOW() + );` + require.NoError(ts.T(), ts.API.db.RawQuery(createJobsTableSQL).Exec()) + + type sendSMSHookTestCase struct { + desc string + uri string + endpoint string + method string + header string + body map[string]string + hookFunctionSQL string + expectedCode int + expectToken bool + hookFunctionIdentifier string + } + cases := []sendSMSHookTestCase{ + { + desc: "Phone signup using Hook", + endpoint: "/signup", + method: http.MethodPost, + uri: "pg-functions://postgres/auth/send_sms_signup", + hookFunctionSQL: ` + create or replace function send_sms_signup(input jsonb) + returns json as $$ + begin + insert into job_queue(job_type, payload) + values ('sms_signup', input); + return input; + end; $$ language plpgsql;`, + header: "", + body: map[string]string{ + "phone": "1234567890", + "password": "testpassword", + }, + expectedCode: http.StatusOK, + hookFunctionIdentifier: "send_sms_signup(input jsonb)", + }, + { + desc: "SMS OTP sign in using hook", + endpoint: "/otp", + method: http.MethodPost, + uri: "pg-functions://postgres/auth/send_sms_otp", + hookFunctionSQL: ` + create or replace function send_sms_otp(input jsonb) + returns json as $$ + begin + insert into job_queue(job_type, payload) + values ('sms_signup', input); + return input; + end; $$ language plpgsql;`, + header: "", + body: map[string]string{ + "phone": "123456789", + }, + expectToken: false, + expectedCode: http.StatusOK, + hookFunctionIdentifier: "send_sms_otp(input jsonb)", + }, + { + desc: "Phone Change", + endpoint: "/user", + method: http.MethodPut, + uri: "pg-functions://postgres/auth/send_sms_phone_change", + hookFunctionSQL: ` + create or replace function send_sms_phone_change(input jsonb) + returns json as $$ + begin + insert into job_queue(job_type, payload) + values ('phone_change', input); + return input; + end; $$ language plpgsql;`, + header: token, + body: map[string]string{ + "phone": "111111111", + }, + expectToken: true, + expectedCode: http.StatusOK, + hookFunctionIdentifier: "send_sms_phone_change(input jsonb)", + }, + { + desc: "Reauthenticate", + endpoint: "/reauthenticate", + method: http.MethodGet, + uri: "pg-functions://postgres/auth/reauthenticate", + hookFunctionSQL: ` + create or replace function reauthenticate(input jsonb) + returns json as $$ + begin + return input; + end; $$ language plpgsql;`, + header: "", + body: nil, + expectToken: true, + expectedCode: http.StatusOK, + hookFunctionIdentifier: "reauthenticate(input jsonb)", + }, + { + desc: "SMS OTP Hook (Error)", + endpoint: "/otp", + method: http.MethodPost, + uri: "pg-functions://postgres/auth/send_sms_otp_failure", + hookFunctionSQL: ` + create or replace function send_sms_otp(input jsonb) + returns json as $$ + begin + RAISE EXCEPTION 'Intentional Error for Testing'; + end; $$ language plpgsql;`, + header: "", + body: map[string]string{ + "phone": "123456789", + }, + expectToken: false, + expectedCode: http.StatusInternalServerError, + hookFunctionIdentifier: "send_sms_otp_failure(input jsonb)", + }, + } + + for _, c := range cases { + ts.T().Run(c.desc, func(t *testing.T) { + + ts.Config.External.Phone.Enabled = true + ts.Config.Hook.SendSMS.Enabled = true + ts.Config.Hook.SendSMS.URI = c.uri + // Disable FrequencyLimit to 
allow back to back sending + ts.Config.Sms.MaxFrequency = 0 * time.Second + require.NoError(ts.T(), ts.Config.Hook.SendSMS.PopulateExtensibilityPoint()) + + require.NoError(t, ts.API.db.RawQuery(c.hookFunctionSQL).Exec()) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + req := httptest.NewRequest(c.method, "http://localhost"+c.endpoint, &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(t, c.expectedCode, w.Code, "Unexpected HTTP status code") + + // Delete the function and reset env + cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.SendSMS.HookName) + require.NoError(t, ts.API.db.RawQuery(cleanupHookSQL).Exec()) + ts.Config.Hook.SendSMS.Enabled = false + ts.Config.Sms.MaxFrequency = 1 * time.Second + }) + } + + // Cleanup + deleteJobsTableSQL := `drop table if exists job_queue` + require.NoError(ts.T(), ts.API.db.RawQuery(deleteJobsTableSQL).Exec()) + +} diff --git a/auth_v2.169.0/internal/api/pkce.go b/auth_v2.169.0/internal/api/pkce.go new file mode 100644 index 0000000..5ac7566 --- /dev/null +++ b/auth_v2.169.0/internal/api/pkce.go @@ -0,0 +1,98 @@ +package api + +import ( + "regexp" + + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +const ( + PKCEPrefix = "pkce_" + MinCodeChallengeLength = 43 + MaxCodeChallengeLength = 128 + InvalidPKCEParamsErrorMessage = "PKCE flow requires code_challenge_method and code_challenge" +) + +var codeChallengePattern = regexp.MustCompile("^[a-zA-Z._~0-9-]+$") + +func isValidCodeChallenge(codeChallenge string) (bool, error) { + // See RFC 7636 Section 4.2: https://www.rfc-editor.org/rfc/rfc7636#section-4.2 + switch codeChallengeLength := len(codeChallenge); { + case codeChallengeLength < MinCodeChallengeLength, codeChallengeLength > MaxCodeChallengeLength: + return false, badRequestError(ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength) + case !codeChallengePattern.MatchString(codeChallenge): + return false, badRequestError(ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes") + default: + return true, nil + } +} + +func addFlowPrefixToToken(token string, flowType models.FlowType) string { + if isPKCEFlow(flowType) { + return flowType.String() + "_" + token + } else if isImplicitFlow(flowType) { + return token + } + return token +} + +func issueAuthCode(tx *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod) (string, error) { + flowState, err := models.FindFlowStateByUserID(tx, user.ID.String(), authenticationMethod) + if err != nil && models.IsNotFoundError(err) { + return "", unprocessableEntityError(ErrorCodeFlowStateNotFound, "No valid flow state found for user.") + } else if err != nil { + return "", err + } + if err := flowState.RecordAuthCodeIssuedAtTime(tx); err != nil { + return "", err + } + + return flowState.AuthCode, nil +} + +func isPKCEFlow(flowType models.FlowType) bool { + return flowType == models.PKCEFlow +} + +func isImplicitFlow(flowType models.FlowType) bool { + return flowType == models.ImplicitFlow +} + +func validatePKCEParams(codeChallengeMethod, codeChallenge string) error { + switch true { + case (codeChallenge == "") != 
(codeChallengeMethod == ""): + return badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage) + case codeChallenge != "": + if valid, err := isValidCodeChallenge(codeChallenge); !valid { + return err + } + default: + // if both params are empty, just return nil + return nil + } + return nil +} + +func getFlowFromChallenge(codeChallenge string) models.FlowType { + if codeChallenge != "" { + return models.PKCEFlow + } else { + return models.ImplicitFlow + } +} + +// Should only be used with Auth Code of PKCE Flows +func generateFlowState(tx *storage.Connection, providerType string, authenticationMethod models.AuthenticationMethod, codeChallengeMethodParam string, codeChallenge string, userID *uuid.UUID) (*models.FlowState, error) { + codeChallengeMethod, err := models.ParseCodeChallengeMethod(codeChallengeMethodParam) + if err != nil { + return nil, err + } + flowState := models.NewFlowState(providerType, codeChallenge, codeChallengeMethod, authenticationMethod, userID) + if err := tx.Create(flowState); err != nil { + return nil, err + } + return flowState, nil + +} diff --git a/auth_v2.169.0/internal/api/provider/apple.go b/auth_v2.169.0/internal/api/provider/apple.go new file mode 100644 index 0000000..508eaf1 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/apple.go @@ -0,0 +1,144 @@ +package provider + +import ( + "context" + "encoding/json" + "net/url" + "strconv" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const IssuerApple = "https://appleid.apple.com" + +// AppleProvider stores the custom config for apple provider +type AppleProvider struct { + *oauth2.Config + oidc *oidc.Provider +} + +type IsPrivateEmail bool + +// Apple returns an is_private_email field that could be a string or boolean value so we need to implement a custom unmarshaler +// https://developer.apple.com/documentation/sign_in_with_apple/sign_in_with_apple_rest_api/authenticating_users_with_sign_in_with_apple +func (b *IsPrivateEmail) UnmarshalJSON(data []byte) error { + var boolVal bool + if err := json.Unmarshal(data, &boolVal); err == nil { + *b = IsPrivateEmail(boolVal) + return nil + } + + // ignore the error and try to unmarshal as a string + var strVal string + if err := json.Unmarshal(data, &strVal); err != nil { + return err + } + + var err error + boolVal, err = strconv.ParseBool(strVal) + if err != nil { + return err + } + + *b = IsPrivateEmail(boolVal) + return nil +} + +type appleName struct { + FirstName string `json:"firstName"` + LastName string `json:"lastName"` +} + +type appleUser struct { + Name appleName `json:"name"` + Email string `json:"email"` +} + +// NewAppleProvider creates a Apple account provider. 
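+// It performs OIDC discovery against https://appleid.apple.com (IssuerApple)
+// and requests the "email" and "name" scopes.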
+func NewAppleProvider(ctx context.Context, ext conf.OAuthProviderConfiguration) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + if ext.URL != "" { + logrus.Warn("Apple OAuth provider has URL config set which is ignored (check GOTRUE_EXTERNAL_APPLE_URL)") + } + + oidcProvider, err := oidc.NewProvider(ctx, IssuerApple) + if err != nil { + return nil, err + } + + return &AppleProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oidcProvider.Endpoint(), + Scopes: []string{ + "email", + "name", + }, + RedirectURL: ext.RedirectURI, + }, + oidc: oidcProvider, + }, nil +} + +// GetOAuthToken returns the apple provider access token +func (p AppleProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + opts := []oauth2.AuthCodeOption{ + oauth2.SetAuthURLParam("client_id", p.ClientID), + oauth2.SetAuthURLParam("secret", p.ClientSecret), + } + return p.Exchange(context.Background(), code, opts...) +} + +func (p AppleProvider) AuthCodeURL(state string, args ...oauth2.AuthCodeOption) string { + opts := make([]oauth2.AuthCodeOption, 0, 1) + opts = append(opts, oauth2.SetAuthURLParam("response_mode", "form_post")) + authURL := p.Config.AuthCodeURL(state, opts...) + if authURL != "" { + if u, err := url.Parse(authURL); err != nil { + u.RawQuery = strings.ReplaceAll(u.RawQuery, "+", "%20") + authURL = u.String() + } + } + return authURL +} + +// GetUserData returns the user data fetched from the apple provider +func (p AppleProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + idToken := tok.Extra("id_token") + if tok.AccessToken == "" || idToken == nil { + // Apple returns user data only the first time + return &UserProvidedData{}, nil + } + + _, data, err := ParseIDToken(ctx, p.oidc, &oidc.Config{ + ClientID: p.ClientID, + }, idToken.(string), ParseIDTokenOptions{ + AccessToken: tok.AccessToken, + }) + if err != nil { + return nil, err + } + + return data, nil +} + +// ParseUser parses the apple user's info +func (p AppleProvider) ParseUser(data string, userData *UserProvidedData) error { + u := &appleUser{} + err := json.Unmarshal([]byte(data), u) + if err != nil { + return err + } + + userData.Metadata.Name = strings.TrimSpace(u.Name.FirstName + " " + u.Name.LastName) + userData.Metadata.FullName = strings.TrimSpace(u.Name.FirstName + " " + u.Name.LastName) + return nil +} diff --git a/auth_v2.169.0/internal/api/provider/azure.go b/auth_v2.169.0/internal/api/provider/azure.go new file mode 100644 index 0000000..4a341f4 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/azure.go @@ -0,0 +1,164 @@ +package provider + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "regexp" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const IssuerAzureCommon = "https://login.microsoftonline.com/common/v2.0" +const IssuerAzureOrganizations = "https://login.microsoftonline.com/organizations/v2.0" + +// IssuerAzureMicrosoft is the OIDC issuer for microsoft.com accounts: +// https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference#payload-claims +const IssuerAzureMicrosoft = "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0" + +const ( + defaultAzureAuthBase = "login.microsoftonline.com/common" +) + +type azureProvider struct { + *oauth2.Config + + // ExpectedIssuer contains the OIDC issuer that should be expected 
when + // the authorize flow completes. For example, when using the "common" + // endpoint the authorization flow will end with an ID token that + // contains any issuer. In this case, ExpectedIssuer is an empty + // string, because any issuer is allowed. But if a developer sets up a + // tenant-specific authorization endpoint, then we must ensure that the + // ID token received is issued by that specific issuer, and so + // ExpectedIssuer contains the issuer URL of that tenant. + ExpectedIssuer string +} + +var azureIssuerRegexp = regexp.MustCompile("^https://login[.]microsoftonline[.]com/([^/]+)/v2[.]0/?$") +var azureCIAMIssuerRegexp = regexp.MustCompile("^https://[a-z0-9-]+[.]ciamlogin[.]com/([^/]+)/v2[.]0/?$") + +func IsAzureIssuer(issuer string) bool { + return azureIssuerRegexp.MatchString(issuer) +} + +func IsAzureCIAMIssuer(issuer string) bool { + return azureCIAMIssuerRegexp.MatchString(issuer) +} + +// NewAzureProvider creates a Azure account provider. +func NewAzureProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + oauthScopes := []string{"openid"} + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + authHost := chooseHost(ext.URL, defaultAzureAuthBase) + expectedIssuer := "" + + if ext.URL != "" { + expectedIssuer = authHost + "/v2.0" + + if !IsAzureIssuer(expectedIssuer) || !IsAzureCIAMIssuer(expectedIssuer) || expectedIssuer == IssuerAzureCommon || expectedIssuer == IssuerAzureOrganizations { + // in tests, the URL is a local server which should not + // be the expected issuer + // also, IssuerAzure (common) never actually issues any + // ID tokens so it needs to be ignored + expectedIssuer = "" + } + } + + return &azureProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authHost + "/oauth2/v2.0/authorize", + TokenURL: authHost + "/oauth2/v2.0/token", + }, + RedirectURL: ext.RedirectURI, + Scopes: oauthScopes, + }, + ExpectedIssuer: expectedIssuer, + }, nil +} + +func (g azureProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func DetectAzureIDTokenIssuer(ctx context.Context, idToken string) (string, error) { + var payload struct { + Issuer string `json:"iss"` + } + + parts := strings.Split(idToken, ".") + if len(parts) != 3 { + return "", fmt.Errorf("azure: invalid ID token") + } + + payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return "", fmt.Errorf("azure: invalid ID token %w", err) + } + + if err := json.Unmarshal(payloadBytes, &payload); err != nil { + return "", fmt.Errorf("azure: invalid ID token %w", err) + } + + return payload.Issuer, nil +} + +func (g azureProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + idToken := tok.Extra("id_token") + + if idToken != nil { + issuer, err := DetectAzureIDTokenIssuer(ctx, idToken.(string)) + if err != nil { + return nil, err + } + + // Allow basic Azure issuers, except when the expected issuer + // is configured to be the Azure CIAM issuer, allow CIAM + // issuers to pass. 
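Editor's note: DetectAzureIDTokenIssuer above only needs the iss claim, so it base64-decodes the middle JWT segment without verifying the signature (verification happens later through the OIDC provider). A standalone sketch of that trick follows, with a hypothetical peekIssuer helper and a fabricated unsigned token, assuming only the standard library.

package main

import (
    "encoding/base64"
    "encoding/json"
    "fmt"
    "strings"
)

// peekIssuer reads the iss claim from a JWT without verifying the signature,
// the same shortcut DetectAzureIDTokenIssuer uses to pick the right issuer.
func peekIssuer(idToken string) (string, error) {
    parts := strings.Split(idToken, ".")
    if len(parts) != 3 {
        return "", fmt.Errorf("invalid JWT: expected 3 segments, got %d", len(parts))
    }
    payload, err := base64.RawURLEncoding.DecodeString(parts[1])
    if err != nil {
        return "", err
    }
    var claims struct {
        Issuer string `json:"iss"`
    }
    if err := json.Unmarshal(payload, &claims); err != nil {
        return "", err
    }
    return claims.Issuer, nil
}

func main() {
    // Build a fake, unsigned token just to exercise the parsing path.
    header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none"}`))
    payload := base64.RawURLEncoding.EncodeToString([]byte(`{"iss":"https://login.microsoftonline.com/common/v2.0"}`))
    iss, err := peekIssuer(header + "." + payload + ".")
    if err != nil {
        panic(err)
    }
    fmt.Println(iss)
}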
+ if !IsAzureIssuer(issuer) && (IsAzureCIAMIssuer(g.ExpectedIssuer) && !IsAzureCIAMIssuer(issuer)) { + return nil, fmt.Errorf("azure: ID token issuer not valid %q", issuer) + } + + if g.ExpectedIssuer != "" && issuer != g.ExpectedIssuer { + // Since ExpectedIssuer was set, then the developer had + // setup GoTrue to use the tenant-specific + // authorization endpoint, which in-turn means that + // only those tenant's ID tokens will be accepted. + return nil, fmt.Errorf("azure: ID token issuer %q does not match expected issuer %q", issuer, g.ExpectedIssuer) + } + + provider, err := oidc.NewProvider(ctx, issuer) + if err != nil { + return nil, err + } + + _, data, err := ParseIDToken(ctx, provider, &oidc.Config{ + ClientID: g.ClientID, + }, idToken.(string), ParseIDTokenOptions{ + AccessToken: tok.AccessToken, + }) + if err != nil { + return nil, err + } + + return data, nil + } + + // Only ID tokens supported, UserInfo endpoint has a history of being less secure. + + return nil, fmt.Errorf("azure: no OIDC ID token present in response") +} diff --git a/auth_v2.169.0/internal/api/provider/azure_test.go b/auth_v2.169.0/internal/api/provider/azure_test.go new file mode 100644 index 0000000..316cb08 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/azure_test.go @@ -0,0 +1,29 @@ +package provider + +import "testing" + +func TestIsAzureIssuer(t *testing.T) { + positiveExamples := []string{ + "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0", + "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0/", + "https://login.microsoftonline.com/common/v2.0", + } + + negativeExamples := []string{ + "http://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0", + "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0?something=else", + "https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0/extra", + } + + for _, example := range positiveExamples { + if !IsAzureIssuer(example) { + t.Errorf("Example %q should be treated as a valid Azure issuer", example) + } + } + + for _, example := range negativeExamples { + if IsAzureIssuer(example) { + t.Errorf("Example %q should be treated as not a valid Azure issuer", example) + } + } +} diff --git a/auth_v2.169.0/internal/api/provider/bitbucket.go b/auth_v2.169.0/internal/api/provider/bitbucket.go new file mode 100644 index 0000000..e5fae5c --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/bitbucket.go @@ -0,0 +1,104 @@ +package provider + +import ( + "context" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultBitbucketAuthBase = "bitbucket.org" + defaultBitbucketAPIBase = "api.bitbucket.org" +) + +type bitbucketProvider struct { + *oauth2.Config + APIPath string +} + +type bitbucketUser struct { + Name string `json:"display_name"` + ID string `json:"uuid"` + Avatar struct { + Href string `json:"href"` + } `json:"avatar"` +} + +type bitbucketEmail struct { + Email string `json:"email"` + Primary bool `json:"is_primary"` + Verified bool `json:"is_confirmed"` +} + +type bitbucketEmails struct { + Values []bitbucketEmail `json:"values"` +} + +// NewBitbucketProvider creates a Bitbucket account provider. 
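Editor's note: the TestIsAzureIssuer cases above are driven by azureIssuerRegexp. The short standalone sketch below shows the same pattern in isolation and which tenant segment it captures; the example inputs mirror the test data, and nothing here is part of the repository.

package main

import (
    "fmt"
    "regexp"
)

// Same shape as azureIssuerRegexp above: scheme and host are fixed, the tenant
// segment is captured, and an optional trailing slash is tolerated.
var issuerRe = regexp.MustCompile(`^https://login[.]microsoftonline[.]com/([^/]+)/v2[.]0/?$`)

func main() {
    examples := []string{
        "https://login.microsoftonline.com/common/v2.0",       // matches, tenant "common"
        "https://login.microsoftonline.com/common/v2.0/",      // matches, trailing slash allowed
        "http://login.microsoftonline.com/common/v2.0",        // no match, wrong scheme
        "https://login.microsoftonline.com/common/v2.0/extra", // no match, extra path segment
    }
    for _, e := range examples {
        if m := issuerRe.FindStringSubmatch(e); m != nil {
            fmt.Printf("%s -> tenant %q\n", e, m[1])
        } else {
            fmt.Printf("%s -> not an Azure issuer\n", e)
        }
    }
}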
+func NewBitbucketProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + authHost := chooseHost(ext.URL, defaultBitbucketAuthBase) + apiPath := chooseHost(ext.URL, defaultBitbucketAPIBase) + "/2.0" + + return &bitbucketProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authHost + "/site/oauth2/authorize", + TokenURL: authHost + "/site/oauth2/access_token", + }, + RedirectURL: ext.RedirectURI, + Scopes: []string{"account", "email"}, + }, + APIPath: apiPath, + }, nil +} + +func (g bitbucketProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g bitbucketProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u bitbucketUser + if err := makeRequest(ctx, tok, g.Config, g.APIPath+"/user", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + + var emails bitbucketEmails + if err := makeRequest(ctx, tok, g.Config, g.APIPath+"/user/emails", &emails); err != nil { + return nil, err + } + + if len(emails.Values) > 0 { + for _, e := range emails.Values { + if e.Email != "" { + data.Emails = append(data.Emails, Email{ + Email: e.Email, + Verified: e.Verified, + Primary: e.Primary, + }) + } + } + } + + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.ID, + Name: u.Name, + Picture: u.Avatar.Href, + + // To be deprecated + AvatarURL: u.Avatar.Href, + FullName: u.Name, + ProviderId: u.ID, + } + + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/discord.go b/auth_v2.169.0/internal/api/provider/discord.go new file mode 100644 index 0000000..50d413b --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/discord.go @@ -0,0 +1,120 @@ +package provider + +import ( + "context" + "fmt" + "strconv" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultDiscordAPIBase = "discord.com" +) + +type discordProvider struct { + *oauth2.Config + APIPath string +} + +type discordUser struct { + Avatar string `json:"avatar"` + Discriminator string `json:"discriminator"` + Email string `json:"email"` + ID string `json:"id"` + Name string `json:"username"` + GlobalName string `json:"global_name"` + Verified bool `json:"verified"` +} + +// NewDiscordProvider creates a Discord account provider. +func NewDiscordProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + apiPath := chooseHost(ext.URL, defaultDiscordAPIBase) + "/api" + + oauthScopes := []string{ + "email", + "identify", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) 
+ } + + return &discordProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: apiPath + "/oauth2/authorize", + TokenURL: apiPath + "/oauth2/token", + }, + Scopes: oauthScopes, + RedirectURL: ext.RedirectURI, + }, + APIPath: apiPath, + }, nil +} + +func (g discordProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g discordProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u discordUser + if err := makeRequest(ctx, tok, g.Config, g.APIPath+"/users/@me", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + Verified: u.Verified, + Primary: true, + }} + } + + var avatarURL string + extension := "png" + if u.Avatar == "" { + if intDiscriminator, err := strconv.Atoi(u.Discriminator); err != nil { + return nil, err + } else { + // https://discord.com/developers/docs/reference#image-formatting-cdn-endpoints: + // In the case of the Default User Avatar endpoint, the value for + // user_discriminator in the path should be the user's discriminator modulo 5 + avatarURL = fmt.Sprintf("https://cdn.discordapp.com/embed/avatars/%d.%s", intDiscriminator%5, extension) + } + } else { + // https://discord.com/developers/docs/reference#image-formatting: + // "In the case of endpoints that support GIFs, the hash will begin with a_ + // if it is available in GIF format." + if strings.HasPrefix(u.Avatar, "a_") { + extension = "gif" + } + avatarURL = fmt.Sprintf("https://cdn.discordapp.com/avatars/%s/%s.%s", u.ID, u.Avatar, extension) + } + + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.ID, + Name: fmt.Sprintf("%v#%v", u.Name, u.Discriminator), + Picture: avatarURL, + CustomClaims: map[string]interface{}{ + "global_name": u.GlobalName, + }, + + // To be deprecated + AvatarURL: avatarURL, + FullName: u.Name, + ProviderId: u.ID, + } + + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/errors.go b/auth_v2.169.0/internal/api/provider/errors.go new file mode 100644 index 0000000..67a20ea --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/errors.go @@ -0,0 +1,49 @@ +package provider + +import "fmt" + +type HTTPError struct { + Code int `json:"code"` + Message string `json:"msg"` + InternalError error `json:"-"` + InternalMessage string `json:"-"` + ErrorID string `json:"error_id,omitempty"` +} + +func (e *HTTPError) Error() string { + if e.InternalMessage != "" { + return e.InternalMessage + } + return fmt.Sprintf("%d: %s", e.Code, e.Message) +} + +func (e *HTTPError) Is(target error) bool { + return e.Error() == target.Error() +} + +// Cause returns the root cause error +func (e *HTTPError) Cause() error { + if e.InternalError != nil { + return e.InternalError + } + return e +} + +// WithInternalError adds internal error information to the error +func (e *HTTPError) WithInternalError(err error) *HTTPError { + e.InternalError = err + return e +} + +// WithInternalMessage adds internal message information to the error +func (e *HTTPError) WithInternalMessage(fmtString string, args ...interface{}) *HTTPError { + e.InternalMessage = fmt.Sprintf(fmtString, args...) 
+ return e +} + +func httpError(code int, fmtString string, args ...interface{}) *HTTPError { + return &HTTPError{ + Code: code, + Message: fmt.Sprintf(fmtString, args...), + } +} diff --git a/auth_v2.169.0/internal/api/provider/facebook.go b/auth_v2.169.0/internal/api/provider/facebook.go new file mode 100644 index 0000000..e73c419 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/facebook.go @@ -0,0 +1,112 @@ +package provider + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const IssuerFacebook = "https://www.facebook.com" + +const ( + defaultFacebookAuthBase = "www.facebook.com" + defaultFacebookTokenBase = "graph.facebook.com" //#nosec G101 -- Not a secret value. + defaultFacebookAPIBase = "graph.facebook.com" +) + +type facebookProvider struct { + *oauth2.Config + ProfileURL string +} + +type facebookUser struct { + ID string `json:"id"` + Email string `json:"email"` + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Alias string `json:"name"` + Avatar struct { + Data struct { + URL string `json:"url"` + } `json:"data"` + } `json:"picture"` +} + +// NewFacebookProvider creates a Facebook account provider. +func NewFacebookProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + authHost := chooseHost(ext.URL, defaultFacebookAuthBase) + tokenHost := chooseHost(ext.URL, defaultFacebookTokenBase) + profileURL := chooseHost(ext.URL, defaultFacebookAPIBase) + "/me?fields=email,first_name,last_name,name,picture" + + oauthScopes := []string{ + "email", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) 
+ } + + return &facebookProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + RedirectURL: ext.RedirectURI, + Endpoint: oauth2.Endpoint{ + AuthURL: authHost + "/dialog/oauth", + TokenURL: tokenHost + "/oauth/access_token", + }, + Scopes: oauthScopes, + }, + ProfileURL: profileURL, + }, nil +} + +func (p facebookProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return p.Exchange(context.Background(), code) +} + +func (p facebookProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + hash := hmac.New(sha256.New, []byte(p.Config.ClientSecret)) + hash.Write([]byte(tok.AccessToken)) + appsecretProof := hex.EncodeToString(hash.Sum(nil)) + + var u facebookUser + url := p.ProfileURL + "&appsecret_proof=" + appsecretProof + if err := makeRequest(ctx, tok, p.Config, url, &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + Verified: true, + Primary: true, + }} + } + + data.Metadata = &Claims{ + Issuer: p.ProfileURL, + Subject: u.ID, + Name: strings.TrimSpace(u.FirstName + " " + u.LastName), + NickName: u.Alias, + Picture: u.Avatar.Data.URL, + + // To be deprecated + Slug: u.Alias, + AvatarURL: u.Avatar.Data.URL, + FullName: strings.TrimSpace(u.FirstName + " " + u.LastName), + ProviderId: u.ID, + } + + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/figma.go b/auth_v2.169.0/internal/api/provider/figma.go new file mode 100644 index 0000000..ba812da --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/figma.go @@ -0,0 +1,95 @@ +package provider + +import ( + "context" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +// Figma +// Reference: https://www.figma.com/developers/api#oauth2 + +const ( + defaultFigmaAuthBase = "www.figma.com" + defaultFigmaAPIBase = "api.figma.com" +) + +type figmaProvider struct { + *oauth2.Config + APIHost string +} + +type figmaUser struct { + ID string `json:"id"` + Email string `json:"email"` + Name string `json:"handle"` + AvatarURL string `json:"img_url"` +} + +// NewFigmaProvider creates a Figma account provider. +func NewFigmaProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + authHost := chooseHost(ext.URL, defaultFigmaAuthBase) + apiHost := chooseHost(ext.URL, defaultFigmaAPIBase) + + oauthScopes := []string{ + "files:read", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) 
+ } + + return &figmaProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authHost + "/oauth", + TokenURL: authHost + "/api/oauth/token", + }, + RedirectURL: ext.RedirectURI, + Scopes: oauthScopes, + }, + APIHost: apiHost, + }, nil +} + +func (p figmaProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return p.Exchange(context.Background(), code) +} + +func (p figmaProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u figmaUser + if err := makeRequest(ctx, tok, p.Config, p.APIHost+"/v1/me", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + Verified: true, + Primary: true, + }} + } + + data.Metadata = &Claims{ + Issuer: p.APIHost, + Subject: u.ID, + Name: u.Name, + Email: u.Email, + EmailVerified: true, + + // To be deprecated + AvatarURL: u.AvatarURL, + FullName: u.Name, + ProviderId: u.ID, + } + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/fly.go b/auth_v2.169.0/internal/api/provider/fly.go new file mode 100644 index 0000000..d933752 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/fly.go @@ -0,0 +1,103 @@ +package provider + +import ( + "context" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultFlyAPIBase = "api.fly.io" +) + +type flyProvider struct { + *oauth2.Config + APIPath string +} + +type flyUser struct { + ResourceOwnerID string `json:"resource_owner_id"` + UserID string `json:"user_id"` + UserName string `json:"user_name"` + Email string `json:"email"` + Organizations []struct { + ID string `json:"id"` + Role string `json:"role"` + } `json:"organizations"` + Scope []string `json:"scope"` + Application map[string]string `json:"application"` + ExpiresIn int `json:"expires_in"` + CreatedAt int `json:"created_at"` +} + +// NewFlyProvider creates a Fly oauth provider. +func NewFlyProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + authHost := chooseHost(ext.URL, defaultFlyAPIBase) + + // Fly only provides the "read" scope. + // https://fly.io/docs/reference/extensions_api/#single-sign-on-flow + oauthScopes := []string{ + "read", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) 
+ } + + return &flyProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authHost + "/oauth/authorize", + TokenURL: authHost + "/oauth/token", + }, + RedirectURL: ext.RedirectURI, + Scopes: oauthScopes, + }, + APIPath: authHost, + }, nil +} + +func (p flyProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return p.Exchange(context.Background(), code) +} + +func (p flyProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u flyUser + if err := makeRequest(ctx, tok, p.Config, p.APIPath+"/oauth/token/info", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + Verified: true, + Primary: true, + }} + } + + data.Metadata = &Claims{ + Issuer: p.APIPath, + Subject: u.UserID, + FullName: u.UserName, + Email: u.Email, + EmailVerified: true, + ProviderId: u.UserID, + CustomClaims: map[string]interface{}{ + "resource_owner_id": u.ResourceOwnerID, + "organizations": u.Organizations, + "application": u.Application, + "scope": u.Scope, + "created_at": u.CreatedAt, + }, + } + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/github.go b/auth_v2.169.0/internal/api/provider/github.go new file mode 100644 index 0000000..0da3e88 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/github.go @@ -0,0 +1,110 @@ +package provider + +import ( + "context" + "strconv" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +// Github + +const ( + defaultGitHubAuthBase = "github.com" + defaultGitHubAPIBase = "api.github.com" +) + +type githubProvider struct { + *oauth2.Config + APIHost string +} + +type githubUser struct { + ID int `json:"id"` + UserName string `json:"login"` + Email string `json:"email"` + Name string `json:"name"` + AvatarURL string `json:"avatar_url"` +} + +type githubUserEmail struct { + Email string `json:"email"` + Primary bool `json:"primary"` + Verified bool `json:"verified"` +} + +// NewGithubProvider creates a Github account provider. +func NewGithubProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + authHost := chooseHost(ext.URL, defaultGitHubAuthBase) + apiHost := chooseHost(ext.URL, defaultGitHubAPIBase) + if !strings.HasSuffix(apiHost, defaultGitHubAPIBase) { + apiHost += "/api/v3" + } + + oauthScopes := []string{ + "user:email", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) 
+ } + + return &githubProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authHost + "/login/oauth/authorize", + TokenURL: authHost + "/login/oauth/access_token", + }, + RedirectURL: ext.RedirectURI, + Scopes: oauthScopes, + }, + APIHost: apiHost, + }, nil +} + +func (g githubProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g githubProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u githubUser + if err := makeRequest(ctx, tok, g.Config, g.APIHost+"/user", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{ + Metadata: &Claims{ + Issuer: g.APIHost, + Subject: strconv.Itoa(u.ID), + Name: u.Name, + PreferredUsername: u.UserName, + + // To be deprecated + AvatarURL: u.AvatarURL, + FullName: u.Name, + ProviderId: strconv.Itoa(u.ID), + UserNameKey: u.UserName, + }, + } + + var emails []*githubUserEmail + if err := makeRequest(ctx, tok, g.Config, g.APIHost+"/user/emails", &emails); err != nil { + return nil, err + } + + for _, e := range emails { + if e.Email != "" { + data.Emails = append(data.Emails, Email{Email: e.Email, Verified: e.Verified, Primary: e.Primary}) + } + } + + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/gitlab.go b/auth_v2.169.0/internal/api/provider/gitlab.go new file mode 100644 index 0000000..4b5d70c --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/gitlab.go @@ -0,0 +1,107 @@ +package provider + +import ( + "context" + "strconv" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +// Gitlab + +const defaultGitLabAuthBase = "gitlab.com" + +type gitlabProvider struct { + *oauth2.Config + Host string +} + +type gitlabUser struct { + Email string `json:"email"` + Name string `json:"name"` + AvatarURL string `json:"avatar_url"` + ConfirmedAt string `json:"confirmed_at"` + ID int `json:"id"` +} + +type gitlabUserEmail struct { + ID int `json:"id"` + Email string `json:"email"` +} + +// NewGitlabProvider creates a Gitlab account provider. +func NewGitlabProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + oauthScopes := []string{ + "read_user", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) 
+ } + + host := chooseHost(ext.URL, defaultGitLabAuthBase) + return &gitlabProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: host + "/oauth/authorize", + TokenURL: host + "/oauth/token", + }, + RedirectURL: ext.RedirectURI, + Scopes: oauthScopes, + }, + Host: host, + }, nil +} + +func (g gitlabProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g gitlabProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u gitlabUser + + if err := makeRequest(ctx, tok, g.Config, g.Host+"/api/v4/user", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + + var emails []*gitlabUserEmail + if err := makeRequest(ctx, tok, g.Config, g.Host+"/api/v4/user/emails", &emails); err != nil { + return nil, err + } + + for _, e := range emails { + // additional emails from GitLab don't return confirm status + if e.Email != "" { + data.Emails = append(data.Emails, Email{Email: e.Email, Verified: false, Primary: false}) + } + } + + if u.Email != "" { + verified := u.ConfirmedAt != "" + data.Emails = append(data.Emails, Email{Email: u.Email, Verified: verified, Primary: true}) + } + + data.Metadata = &Claims{ + Issuer: g.Host, + Subject: strconv.Itoa(u.ID), + Name: u.Name, + Picture: u.AvatarURL, + + // To be deprecated + AvatarURL: u.AvatarURL, + FullName: u.Name, + ProviderId: strconv.Itoa(u.ID), + } + + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/google.go b/auth_v2.169.0/internal/api/provider/google.go new file mode 100644 index 0000000..03b76ae --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/google.go @@ -0,0 +1,144 @@ +package provider + +import ( + "context" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +type googleUser struct { + ID string `json:"id"` + Subject string `json:"sub"` + Issuer string `json:"iss"` + Name string `json:"name"` + AvatarURL string `json:"picture"` + Email string `json:"email"` + VerifiedEmail bool `json:"verified_email"` + EmailVerified bool `json:"email_verified"` + HostedDomain string `json:"hd"` +} + +func (u googleUser) IsEmailVerified() bool { + return u.VerifiedEmail || u.EmailVerified +} + +const IssuerGoogle = "https://accounts.google.com" + +var internalIssuerGoogle = IssuerGoogle + +type googleProvider struct { + *oauth2.Config + + oidc *oidc.Provider +} + +// NewGoogleProvider creates a Google OAuth2 identity provider. +func NewGoogleProvider(ctx context.Context, ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + if ext.URL != "" { + logrus.Warn("Google OAuth provider has URL config set which is ignored (check GOTRUE_EXTERNAL_GOOGLE_URL)") + } + + oauthScopes := []string{ + "email", + "profile", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) 
+ } + + oidcProvider, err := oidc.NewProvider(ctx, internalIssuerGoogle) + if err != nil { + return nil, err + } + + return &googleProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oidcProvider.Endpoint(), + Scopes: oauthScopes, + RedirectURL: ext.RedirectURI, + }, + oidc: oidcProvider, + }, nil +} + +func (g googleProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +const UserInfoEndpointGoogle = "https://www.googleapis.com/userinfo/v2/me" + +var internalUserInfoEndpointGoogle = UserInfoEndpointGoogle + +func (g googleProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + if idToken := tok.Extra("id_token"); idToken != nil { + _, data, err := ParseIDToken(ctx, g.oidc, &oidc.Config{ + ClientID: g.Config.ClientID, + }, idToken.(string), ParseIDTokenOptions{ + AccessToken: tok.AccessToken, + }) + if err != nil { + return nil, err + } + + return data, nil + } + + // This whole section offers legacy support in case the Google OAuth2 + // flow does not return an ID Token for the user, which appears to + // always be the case. + logrus.Info("Using Google OAuth2 user info endpoint, an ID token was not returned by Google") + + var u googleUser + if err := makeRequest(ctx, tok, g.Config, internalUserInfoEndpointGoogle, &u); err != nil { + return nil, err + } + + var data UserProvidedData + + if u.Email != "" { + data.Emails = append(data.Emails, Email{ + Email: u.Email, + Verified: u.IsEmailVerified(), + Primary: true, + }) + } + + data.Metadata = &Claims{ + Issuer: internalUserInfoEndpointGoogle, + Subject: u.ID, + Name: u.Name, + Picture: u.AvatarURL, + Email: u.Email, + EmailVerified: u.IsEmailVerified(), + + // To be deprecated + AvatarURL: u.AvatarURL, + FullName: u.Name, + ProviderId: u.ID, + } + + return &data, nil +} + +// ResetGoogleProvider should only be used in tests! +func ResetGoogleProvider() { + internalIssuerGoogle = IssuerGoogle + internalUserInfoEndpointGoogle = UserInfoEndpointGoogle +} + +// OverrideGoogleProvider should only be used in tests! 
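Editor's note: ResetGoogleProvider and OverrideGoogleProvider above rely on package-level variables that tests can repoint at a local server and later restore. A minimal standalone sketch of that override-and-reset idiom, using hypothetical names (realEndpoint, overrideEndpoint) rather than the real ones:

package main

import "fmt"

// The Google provider keeps the production issuer/userinfo URLs in package-level
// variables so tests can swap them out and put them back afterwards.
const realEndpoint = "https://example.com/userinfo"

var activeEndpoint = realEndpoint

// overrideEndpoint swaps the endpoint (tests only) and returns a restore func.
func overrideEndpoint(url string) func() {
    activeEndpoint = url
    return func() { activeEndpoint = realEndpoint }
}

func main() {
    restore := overrideEndpoint("http://127.0.0.1:9999/userinfo")
    fmt.Println("during test:", activeEndpoint)
    restore()
    fmt.Println("after reset:", activeEndpoint)
}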
+func OverrideGoogleProvider(issuer, userInfo string) { + internalIssuerGoogle = issuer + internalUserInfoEndpointGoogle = userInfo +} diff --git a/auth_v2.169.0/internal/api/provider/kakao.go b/auth_v2.169.0/internal/api/provider/kakao.go new file mode 100644 index 0000000..2482b97 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/kakao.go @@ -0,0 +1,107 @@ +package provider + +import ( + "context" + "strconv" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultKakaoAuthBase = "kauth.kakao.com" + defaultKakaoAPIBase = "kapi.kakao.com" + IssuerKakao = "https://kauth.kakao.com" +) + +type kakaoProvider struct { + *oauth2.Config + APIHost string +} + +type kakaoUser struct { + ID int `json:"id"` + Account struct { + Profile struct { + Name string `json:"nickname"` + ProfileImageURL string `json:"profile_image_url"` + } `json:"profile"` + Email string `json:"email"` + EmailValid bool `json:"is_email_valid"` + EmailVerified bool `json:"is_email_verified"` + } `json:"kakao_account"` +} + +func (p kakaoProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return p.Exchange(context.Background(), code) +} + +func (p kakaoProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u kakaoUser + + if err := makeRequest(ctx, tok, p.Config, p.APIHost+"/v2/user/me", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + + if u.Account.Email != "" { + data.Emails = []Email{ + { + Email: u.Account.Email, + Verified: u.Account.EmailVerified && u.Account.EmailValid, + Primary: true, + }, + } + } + + data.Metadata = &Claims{ + Issuer: p.APIHost, + Subject: strconv.Itoa(u.ID), + + Name: u.Account.Profile.Name, + PreferredUsername: u.Account.Profile.Name, + + // To be deprecated + AvatarURL: u.Account.Profile.ProfileImageURL, + FullName: u.Account.Profile.Name, + ProviderId: strconv.Itoa(u.ID), + UserNameKey: u.Account.Profile.Name, + } + return data, nil +} + +func NewKakaoProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + authHost := chooseHost(ext.URL, defaultKakaoAuthBase) + apiHost := chooseHost(ext.URL, defaultKakaoAPIBase) + + oauthScopes := []string{ + "account_email", + "profile_image", + "profile_nickname", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + return &kakaoProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthStyle: oauth2.AuthStyleInParams, + AuthURL: authHost + "/oauth/authorize", + TokenURL: authHost + "/oauth/token", + }, + RedirectURL: ext.RedirectURI, + Scopes: oauthScopes, + }, + APIHost: apiHost, + }, nil +} diff --git a/auth_v2.169.0/internal/api/provider/keycloak.go b/auth_v2.169.0/internal/api/provider/keycloak.go new file mode 100644 index 0000000..39ccec5 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/keycloak.go @@ -0,0 +1,98 @@ +package provider + +import ( + "context" + "errors" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +// Keycloak +type keycloakProvider struct { + *oauth2.Config + Host string +} + +type keycloakUser struct { + Name string `json:"name"` + Sub string `json:"sub"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` +} + +// NewKeycloakProvider creates a Keycloak account provider. 
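Editor's note: every provider constructor above follows the same scope-merging convention, a fixed default list plus an optional comma-separated string from configuration. A standalone sketch of that convention, with a hypothetical mergeScopes helper that is not part of the codebase:

package main

import (
    "fmt"
    "strings"
)

// mergeScopes appends comma-separated extra scopes to a copy of the defaults,
// the pattern repeated in each NewXxxProvider constructor above.
func mergeScopes(defaults []string, extra string) []string {
    scopes := append([]string{}, defaults...)
    if extra != "" {
        scopes = append(scopes, strings.Split(extra, ",")...)
    }
    return scopes
}

func main() {
    fmt.Println(mergeScopes([]string{"profile", "email"}, ""))             // [profile email]
    fmt.Println(mergeScopes([]string{"profile", "email"}, "roles,groups")) // [profile email roles groups]
}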
+func NewKeycloakProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + oauthScopes := []string{ + "profile", + "email", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + if ext.URL == "" { + return nil, errors.New("unable to find URL for the Keycloak provider") + } + + extURLlen := len(ext.URL) + if ext.URL[extURLlen-1] == '/' { + ext.URL = ext.URL[:extURLlen-1] + } + + return &keycloakProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: ext.URL + "/protocol/openid-connect/auth", + TokenURL: ext.URL + "/protocol/openid-connect/token", + }, + RedirectURL: ext.RedirectURI, + Scopes: oauthScopes, + }, + Host: ext.URL, + }, nil +} + +func (g keycloakProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g keycloakProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u keycloakUser + + if err := makeRequest(ctx, tok, g.Config, g.Host+"/protocol/openid-connect/userinfo", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + Verified: u.EmailVerified, + Primary: true, + }} + } + + data.Metadata = &Claims{ + Issuer: g.Host, + Subject: u.Sub, + Name: u.Name, + Email: u.Email, + EmailVerified: u.EmailVerified, + + // To be deprecated + FullName: u.Name, + ProviderId: u.Sub, + } + + return data, nil + +} diff --git a/auth_v2.169.0/internal/api/provider/linkedin.go b/auth_v2.169.0/internal/api/provider/linkedin.go new file mode 100644 index 0000000..bc33515 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/linkedin.go @@ -0,0 +1,149 @@ +package provider + +import ( + "context" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultLinkedinAPIBase = "api.linkedin.com" +) + +type linkedinProvider struct { + *oauth2.Config + APIPath string + UserInfoURL string + UserEmailUrl string +} + +// See https://docs.microsoft.com/en-us/linkedin/consumer/integrations/self-serve/sign-in-with-linkedin?context=linkedin/consumer/context +// for retrieving a member's profile. This requires the r_liteprofile scope. +type linkedinUser struct { + ID string `json:"id"` + FirstName linkedinName `json:"firstName"` + LastName linkedinName `json:"lastName"` + AvatarURL struct { + DisplayImage struct { + Elements []struct { + Identifiers []struct { + Identifier string `json:"identifier"` + } `json:"identifiers"` + } `json:"elements"` + } `json:"displayImage~"` + } `json:"profilePicture"` +} + +func (u *linkedinUser) getAvatarUrl() string { + avatarURL := "" + if len(u.AvatarURL.DisplayImage.Elements) > 0 { + avatarURL = u.AvatarURL.DisplayImage.Elements[0].Identifiers[0].Identifier + } + return avatarURL +} + +type linkedinName struct { + Localized interface{} `json:"localized"` + PreferredLocale linkedinLocale `json:"preferredLocale"` +} + +type linkedinLocale struct { + Country string `json:"country"` + Language string `json:"language"` +} + +// See https://docs.microsoft.com/en-us/linkedin/consumer/integrations/self-serve/sign-in-with-linkedin?context=linkedin/consumer/context#retrieving-member-email-address +// for retrieving a member email address. This requires the r_email_address scope. 
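Editor's note: LinkedIn's v2 profile API returns first and last names as a localized map keyed by "<language>_<country>", which is what GetName below reconstructs. A standalone sketch of that lookup with hypothetical data, assuming only the standard library:

package main

import "fmt"

// localizedName looks up a name by the "<language>_<country>" key that
// LinkedIn's preferredLocale describes, mirroring GetName below.
func localizedName(localized map[string]string, language, country string) string {
    return localized[language+"_"+country]
}

func main() {
    names := map[string]string{"en_US": "Jane", "fr_FR": "Jeanne"}
    fmt.Println(localizedName(names, "en", "US")) // Jane
}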
+type linkedinElements struct { + Elements []struct { + Handle string `json:"handle"` + HandleTilde struct { + EmailAddress string `json:"emailAddress"` + } `json:"handle~"` + } `json:"elements"` +} + +// NewLinkedinProvider creates a Linkedin account provider. +func NewLinkedinProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + apiPath := chooseHost(ext.URL, defaultLinkedinAPIBase) + + oauthScopes := []string{ + "r_emailaddress", + "r_liteprofile", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + return &linkedinProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: apiPath + "/oauth/v2/authorization", + TokenURL: apiPath + "/oauth/v2/accessToken", + }, + Scopes: oauthScopes, + RedirectURL: ext.RedirectURI, + }, + APIPath: apiPath, + }, nil +} + +func (g linkedinProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func GetName(name linkedinName) string { + key := name.PreferredLocale.Language + "_" + name.PreferredLocale.Country + myMap := name.Localized.(map[string]interface{}) + return myMap[key].(string) +} + +func (g linkedinProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u linkedinUser + if err := makeRequest(ctx, tok, g.Config, g.APIPath+"/v2/me?projection=(id,firstName,lastName,profilePicture(displayImage~:playableStreams))", &u); err != nil { + return nil, err + } + + var e linkedinElements + // Note: Use primary contact api for handling phone numbers + if err := makeRequest(ctx, tok, g.Config, g.APIPath+"/v2/emailAddress?q=members&projection=(elements*(handle~))", &e); err != nil { + return nil, err + } + + data := &UserProvidedData{} + + if e.Elements[0].HandleTilde.EmailAddress != "" { + // linkedin only returns the primary email which is verified for the r_emailaddress scope. + data.Emails = []Email{{ + Email: e.Elements[0].HandleTilde.EmailAddress, + Primary: true, + Verified: true, + }} + } + + avatarURL := u.getAvatarUrl() + + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.ID, + Name: strings.TrimSpace(GetName(u.FirstName) + " " + GetName(u.LastName)), + Picture: avatarURL, + Email: e.Elements[0].HandleTilde.EmailAddress, + EmailVerified: true, + + // To be deprecated + AvatarURL: avatarURL, + FullName: strings.TrimSpace(GetName(u.FirstName) + " " + GetName(u.LastName)), + ProviderId: u.ID, + } + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/linkedin_oidc.go b/auth_v2.169.0/internal/api/provider/linkedin_oidc.go new file mode 100644 index 0000000..a5d94fa --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/linkedin_oidc.go @@ -0,0 +1,81 @@ +package provider + +import ( + "context" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultLinkedinOIDCAPIBase = "api.linkedin.com" + IssuerLinkedin = "https://www.linkedin.com/oauth" +) + +type linkedinOIDCProvider struct { + *oauth2.Config + oidc *oidc.Provider + APIPath string +} + +// NewLinkedinOIDCProvider creates a Linkedin account provider via OIDC. 
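Editor's note: GetUserData above indexes e.Elements[0] without checking the slice length, which would panic if LinkedIn ever returned an empty elements array. The standalone sketch below shows a defensive variant; firstEmail and the trimmed-down element type are illustrative, not part of linkedin.go.

package main

import "fmt"

// element is a cut-down copy of the response shape used above.
type element struct {
    HandleTilde struct {
        EmailAddress string `json:"emailAddress"`
    } `json:"handle~"`
}

// firstEmail returns the first email address in an elements response,
// guarding against an empty list instead of indexing Elements[0] directly.
func firstEmail(elements []element) (string, bool) {
    if len(elements) == 0 {
        return "", false
    }
    email := elements[0].HandleTilde.EmailAddress
    return email, email != ""
}

func main() {
    var none []element
    if _, ok := firstEmail(none); !ok {
        fmt.Println("no email returned")
    }
}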
+func NewLinkedinOIDCProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + apiPath := chooseHost(ext.URL, defaultLinkedinOIDCAPIBase) + + oauthScopes := []string{ + "openid", + "email", + "profile", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + oidcProvider, err := oidc.NewProvider(context.Background(), IssuerLinkedin) + if err != nil { + return nil, err + } + + return &linkedinOIDCProvider{ + oidc: oidcProvider, + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: apiPath + "/oauth/v2/authorization", + TokenURL: apiPath + "/oauth/v2/accessToken", + }, + Scopes: oauthScopes, + RedirectURL: ext.RedirectURI, + }, + APIPath: apiPath, + }, nil +} + +func (g linkedinOIDCProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g linkedinOIDCProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + idToken := tok.Extra("id_token") + if tok.AccessToken == "" || idToken == nil { + return &UserProvidedData{}, nil + } + + _, data, err := ParseIDToken(ctx, g.oidc, &oidc.Config{ + ClientID: g.ClientID, + }, idToken.(string), ParseIDTokenOptions{ + AccessToken: tok.AccessToken, + }) + if err != nil { + return nil, err + } + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/notion.go b/auth_v2.169.0/internal/api/provider/notion.go new file mode 100644 index 0000000..f8d0ee7 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/notion.go @@ -0,0 +1,121 @@ +package provider + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" + "golang.org/x/oauth2" +) + +const ( + defaultNotionApiBase = "api.notion.com" + notionApiVersion = "2021-08-16" +) + +type notionProvider struct { + *oauth2.Config + APIPath string +} + +type notionUser struct { + Bot struct { + Owner struct { + User struct { + ID string `json:"id"` + Name string `json:"name"` + AvatarURL string `json:"avatar_url"` + Person struct { + Email string `json:"email"` + } `json:"person"` + } `json:"user"` + } `json:"owner"` + } `json:"bot"` +} + +// NewNotionProvider creates a Notion account provider. 
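Editor's note: the OIDC-based providers above all pull the raw ID token out of the OAuth2 token via tok.Extra("id_token") and a type assertion before calling ParseIDToken. A standalone sketch of just that step, using the golang.org/x/oauth2 package the code already depends on; idTokenFrom is a hypothetical helper.

package main

import (
    "fmt"

    "golang.org/x/oauth2"
)

// idTokenFrom extracts the raw OIDC id_token string from an OAuth2 token,
// the same way the OIDC-based providers above do.
func idTokenFrom(tok *oauth2.Token) (string, bool) {
    raw := tok.Extra("id_token")
    if raw == nil {
        return "", false
    }
    s, ok := raw.(string)
    return s, ok && s != ""
}

func main() {
    tok := (&oauth2.Token{AccessToken: "demo"}).WithExtra(map[string]interface{}{
        "id_token": "header.payload.signature",
    })
    if id, ok := idTokenFrom(tok); ok {
        fmt.Println("got id_token:", id)
    }
}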
+func NewNotionProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + authHost := chooseHost(ext.URL, defaultNotionApiBase) + + return ¬ionProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authHost + "/v1/oauth/authorize", + TokenURL: authHost + "/v1/oauth/token", + }, + RedirectURL: ext.RedirectURI, + }, + APIPath: authHost, + }, nil +} + +func (g notionProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g notionProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u notionUser + + // Perform http request, because we need to set the Notion-Version header + req, err := http.NewRequest("GET", g.APIPath+"/v1/users/me", nil) + + if err != nil { + return nil, err + } + + // set headers + req.Header.Set("Notion-Version", notionApiVersion) + req.Header.Set("Authorization", "Bearer "+tok.AccessToken) + + client := &http.Client{Timeout: defaultTimeout} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer utilities.SafeClose(resp.Body) + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("a %v error occurred with retrieving user from notion", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + err = json.Unmarshal(body, &u) + if err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Bot.Owner.User.Person.Email != "" { + data.Emails = []Email{{ + Email: u.Bot.Owner.User.Person.Email, + Verified: true, // Notion dosen't provide data on if email is verified. + Primary: true, + }} + } + + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.Bot.Owner.User.ID, + Name: u.Bot.Owner.User.Name, + Picture: u.Bot.Owner.User.AvatarURL, + + // To be deprecated + AvatarURL: u.Bot.Owner.User.AvatarURL, + FullName: u.Bot.Owner.User.Name, + ProviderId: u.Bot.Owner.User.ID, + } + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/oidc.go b/auth_v2.169.0/internal/api/provider/oidc.go new file mode 100644 index 0000000..51c88e6 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/oidc.go @@ -0,0 +1,410 @@ +package provider + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/golang-jwt/jwt/v5" +) + +type ParseIDTokenOptions struct { + SkipAccessTokenCheck bool + AccessToken string +} + +// OverrideVerifiers can be used to set a custom verifier for an OIDC provider +// (identified by the provider's Endpoint().AuthURL string). Should only be +// used in tests. +var OverrideVerifiers = make(map[string]func(context.Context, *oidc.Config) *oidc.IDTokenVerifier) + +// OverrideClock can be used to set a custom clock function to be used when +// parsing ID tokens. Should only be used in tests. 
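Editor's note: the Notion call above bypasses the shared makeRequest helper because it must send the Notion-Version header alongside the bearer token. Below is a standalone sketch of that request pattern with a plain http.Client and timeout; the URL, token, and getJSONWithHeaders name are placeholders.

package main

import (
    "fmt"
    "io"
    "net/http"
    "time"
)

// getJSONWithHeaders mirrors the Notion call above: a client with a timeout,
// a bearer token, and the provider-specific version header.
func getJSONWithHeaders(url, bearer, version string) ([]byte, error) {
    req, err := http.NewRequest(http.MethodGet, url, nil)
    if err != nil {
        return nil, err
    }
    req.Header.Set("Authorization", "Bearer "+bearer)
    req.Header.Set("Notion-Version", version)

    client := &http.Client{Timeout: 10 * time.Second}
    resp, err := client.Do(req)
    if err != nil {
        return nil, err
    }
    defer resp.Body.Close()

    if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
        return nil, fmt.Errorf("unexpected status %d from %s", resp.StatusCode, url)
    }
    return io.ReadAll(resp.Body)
}

func main() {
    // Placeholder token; a real call needs valid credentials and will otherwise fail.
    body, err := getJSONWithHeaders("https://api.notion.com/v1/users/me", "secret-token", "2021-08-16")
    fmt.Println(len(body), err)
}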
+var OverrideClock func() time.Time + +func ParseIDToken(ctx context.Context, provider *oidc.Provider, config *oidc.Config, idToken string, options ParseIDTokenOptions) (*oidc.IDToken, *UserProvidedData, error) { + if config == nil { + config = &oidc.Config{ + // aud claim check to be performed by other flows + SkipClientIDCheck: true, + } + } + + if OverrideClock != nil { + clonedConfig := *config + clonedConfig.Now = OverrideClock + config = &clonedConfig + } + + verifier := provider.VerifierContext(ctx, config) + overrideVerifier, ok := OverrideVerifiers[provider.Endpoint().AuthURL] + if ok && overrideVerifier != nil { + verifier = overrideVerifier(ctx, config) + } + + token, err := verifier.Verify(ctx, idToken) + if err != nil { + return nil, nil, err + } + + var data *UserProvidedData + + switch token.Issuer { + case IssuerGoogle: + token, data, err = parseGoogleIDToken(token) + case IssuerApple: + token, data, err = parseAppleIDToken(token) + case IssuerLinkedin: + token, data, err = parseLinkedinIDToken(token) + case IssuerKakao: + token, data, err = parseKakaoIDToken(token) + case IssuerVercelMarketplace: + token, data, err = parseVercelMarketplaceIDToken(token) + default: + if IsAzureIssuer(token.Issuer) { + token, data, err = parseAzureIDToken(token) + } else { + token, data, err = parseGenericIDToken(token) + } + } + + if err != nil { + return nil, nil, err + } + + if !options.SkipAccessTokenCheck && token.AccessTokenHash != "" { + if err := token.VerifyAccessToken(options.AccessToken); err != nil { + return nil, nil, err + } + } + + return token, data, nil +} + +func parseGoogleIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData, error) { + var claims googleUser + if err := token.Claims(&claims); err != nil { + return nil, nil, err + } + + var data UserProvidedData + + if claims.Email != "" { + data.Emails = append(data.Emails, Email{ + Email: claims.Email, + Verified: claims.IsEmailVerified(), + Primary: true, + }) + } + + data.Metadata = &Claims{ + Issuer: claims.Issuer, + Subject: claims.Subject, + Name: claims.Name, + Picture: claims.AvatarURL, + + // To be deprecated + AvatarURL: claims.AvatarURL, + FullName: claims.Name, + ProviderId: claims.Subject, + } + + if claims.HostedDomain != "" { + data.Metadata.CustomClaims = map[string]any{ + "hd": claims.HostedDomain, + } + } + + return token, &data, nil +} + +type AppleIDTokenClaims struct { + jwt.RegisteredClaims + + Email string `json:"email"` + + AuthTime *float64 `json:"auth_time"` + IsPrivateEmail *IsPrivateEmail `json:"is_private_email"` +} + +func parseAppleIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData, error) { + var claims AppleIDTokenClaims + if err := token.Claims(&claims); err != nil { + return nil, nil, err + } + + var data UserProvidedData + + data.Emails = append(data.Emails, Email{ + Email: claims.Email, + Verified: true, + Primary: true, + }) + + data.Metadata = &Claims{ + Issuer: token.Issuer, + Subject: token.Subject, + ProviderId: token.Subject, + CustomClaims: make(map[string]any), + } + + if claims.IsPrivateEmail != nil { + data.Metadata.CustomClaims["is_private_email"] = *claims.IsPrivateEmail + } + + if claims.AuthTime != nil { + data.Metadata.CustomClaims["auth_time"] = *claims.AuthTime + } + + if len(data.Metadata.CustomClaims) < 1 { + data.Metadata.CustomClaims = nil + } + + return token, &data, nil +} + +type LinkedinIDTokenClaims struct { + jwt.RegisteredClaims + + Email string `json:"email"` + EmailVerified string `json:"email_verified"` + FamilyName string 
`json:"family_name"` + GivenName string `json:"given_name"` + Locale string `json:"locale"` + Picture string `json:"picture"` +} + +func parseLinkedinIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData, error) { + var claims LinkedinIDTokenClaims + if err := token.Claims(&claims); err != nil { + return nil, nil, err + } + + var data UserProvidedData + emailVerified, err := strconv.ParseBool(claims.EmailVerified) + if err != nil { + return nil, nil, err + } + + if claims.Email != "" { + data.Emails = append(data.Emails, Email{ + Email: claims.Email, + Verified: emailVerified, + Primary: true, + }) + } + + data.Metadata = &Claims{ + Issuer: token.Issuer, + Subject: token.Subject, + Name: strings.TrimSpace(claims.GivenName + " " + claims.FamilyName), + GivenName: claims.GivenName, + FamilyName: claims.FamilyName, + Locale: claims.Locale, + Picture: claims.Picture, + ProviderId: token.Subject, + } + + return token, &data, nil +} + +type AzureIDTokenClaims struct { + jwt.RegisteredClaims + + Email string `json:"email"` + Name string `json:"name"` + PreferredUsername string `json:"preferred_username"` + XMicrosoftEmailDomainOwnerVerified any `json:"xms_edov"` +} + +func (c *AzureIDTokenClaims) IsEmailVerified() bool { + emailVerified := false + + edov := c.XMicrosoftEmailDomainOwnerVerified + + // If xms_edov is not set, and an email is present or xms_edov is true, + // only then is the email regarded as verified. + // https://learn.microsoft.com/en-us/azure/active-directory/develop/migrate-off-email-claim-authorization#using-the-xms_edov-optional-claim-to-determine-email-verification-status-and-migrate-users + if edov == nil { + // An email is provided, but xms_edov is not -- probably not + // configured, so we must assume the email is verified as Azure + // will only send out a potentially unverified email address in + // single-tenanat apps. + emailVerified = c.Email != "" + } else { + edovBool := false + + // Azure can't be trusted with how they encode the xms_edov + // claim. Sometimes it's "xms_edov": "1", sometimes "xms_edov": true. + switch v := edov.(type) { + case bool: + edovBool = v + + case string: + edovBool = v == "1" || v == "true" + + default: + edovBool = false + } + + emailVerified = c.Email != "" && edovBool + } + + return emailVerified +} + +// removeAzureClaimsFromCustomClaims contains the list of claims to be removed +// from the CustomClaims map. 
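Editor's note: IsEmailVerified above has to normalize xms_edov, which Azure may encode as a JSON boolean or as the strings "1"/"true". A standalone sketch of that normalization with a hypothetical truthyClaim helper:

package main

import "fmt"

// truthyClaim normalizes a claim that may arrive as a JSON boolean or as the
// strings "1"/"true", the same handling IsEmailVerified applies to xms_edov.
func truthyClaim(v interface{}) bool {
    switch t := v.(type) {
    case bool:
        return t
    case string:
        return t == "1" || t == "true"
    default:
        return false
    }
}

func main() {
    for _, v := range []interface{}{true, "1", "true", "0", false, nil, 1.0} {
        fmt.Printf("%v (%T) -> %v\n", v, v, truthyClaim(v))
    }
}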
See: +// https://learn.microsoft.com/en-us/azure/active-directory/develop/id-token-claims-reference +var removeAzureClaimsFromCustomClaims = []string{ + "aud", + "iss", + "iat", + "nbf", + "exp", + "c_hash", + "at_hash", + "aio", + "nonce", + "rh", + "uti", + "jti", + "ver", + "sub", + "name", + "preferred_username", +} + +func parseAzureIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData, error) { + var data UserProvidedData + + var azureClaims AzureIDTokenClaims + if err := token.Claims(&azureClaims); err != nil { + return nil, nil, err + } + + data.Metadata = &Claims{ + Issuer: token.Issuer, + Subject: token.Subject, + ProviderId: token.Subject, + PreferredUsername: azureClaims.PreferredUsername, + FullName: azureClaims.Name, + CustomClaims: make(map[string]any), + } + + if azureClaims.Email != "" { + data.Emails = []Email{{ + Email: azureClaims.Email, + Verified: azureClaims.IsEmailVerified(), + Primary: true, + }} + } + + if err := token.Claims(&data.Metadata.CustomClaims); err != nil { + return nil, nil, err + } + + if data.Metadata.CustomClaims != nil { + for _, claim := range removeAzureClaimsFromCustomClaims { + delete(data.Metadata.CustomClaims, claim) + } + } + + return token, &data, nil +} + +type KakaoIDTokenClaims struct { + jwt.RegisteredClaims + + Email string `json:"email"` + Nickname string `json:"nickname"` + Picture string `json:"picture"` +} + +func parseKakaoIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData, error) { + var claims KakaoIDTokenClaims + + if err := token.Claims(&claims); err != nil { + return nil, nil, err + } + + var data UserProvidedData + + if claims.Email != "" { + data.Emails = append(data.Emails, Email{ + Email: claims.Email, + Verified: true, + Primary: true, + }) + } + + data.Metadata = &Claims{ + Issuer: token.Issuer, + Subject: token.Subject, + Name: claims.Nickname, + PreferredUsername: claims.Nickname, + ProviderId: token.Subject, + Picture: claims.Picture, + } + + return token, &data, nil +} + +type VercelMarketplaceIDTokenClaims struct { + jwt.RegisteredClaims + + UserEmail string `json:"user_email"` + UserName string `json:"user_name"` + UserAvatarUrl string `json:"user_avatar_url"` +} + +func parseVercelMarketplaceIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData, error) { + var claims VercelMarketplaceIDTokenClaims + + if err := token.Claims(&claims); err != nil { + return nil, nil, err + } + + var data UserProvidedData + + data.Emails = append(data.Emails, Email{ + Email: claims.UserEmail, + Verified: true, + Primary: true, + }) + + data.Metadata = &Claims{ + Issuer: token.Issuer, + Subject: token.Subject, + ProviderId: token.Subject, + Name: claims.UserName, + Picture: claims.UserAvatarUrl, + } + + return token, &data, nil +} + +func parseGenericIDToken(token *oidc.IDToken) (*oidc.IDToken, *UserProvidedData, error) { + var data UserProvidedData + + if err := token.Claims(&data.Metadata); err != nil { + return nil, nil, err + } + + if data.Metadata.Email != "" { + data.Emails = append(data.Emails, Email{ + Email: data.Metadata.Email, + Verified: data.Metadata.EmailVerified, + Primary: true, + }) + } + + if len(data.Emails) <= 0 { + return nil, nil, fmt.Errorf("provider: Generic OIDC ID token from issuer %q must contain an email address", token.Issuer) + } + + return token, &data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/oidc_test.go b/auth_v2.169.0/internal/api/provider/oidc_test.go new file mode 100644 index 0000000..e088cd4 --- /dev/null +++ 
b/auth_v2.169.0/internal/api/provider/oidc_test.go @@ -0,0 +1,185 @@ +package provider + +import ( + "context" + "crypto" + "crypto/rsa" + "encoding/base64" + "math/big" + "testing" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/stretchr/testify/require" +) + +type realIDToken struct { + AccessToken string + IDToken string + Time time.Time + Email string + Verifier func(context.Context, *oidc.Config) *oidc.IDTokenVerifier +} + +func googleIDTokenVerifier(ctx context.Context, config *oidc.Config) *oidc.IDTokenVerifier { + keyBytes, err := base64.RawURLEncoding.DecodeString("pP-rCe4jkKX6mq8yP1GcBZcxJzmxKWicHHor1S3Q49u6Oe-bQsk5NsK5mdR7Y7liGV9n0ikXSM42dYKQdxbhKA-7--fFon5isJoHr4fIwL2CCwVm5QWlK37q6PiH2_F1M0hRorHfkCb4nI56ZvfygvuOH4LIS82OzIgmsYbeEfwDRpeMSxWKwlpa3pX3GZ6jG7FgzJGBvmBkagpgsa2JZdyU4gEGMOkHdSzi5Ii-6RGfFLhhI1OMxC9P2JaU5yjMN2pikfFIq_dbpm75yNUGpWJNVywtrlNvvJfA74UMN_lVCAaSR0A03BUMg6ljB65gFllpKF224uWBA8tpjngwKQ") + if err != nil { + panic(err) + } + + n := big.NewInt(0) + n.SetBytes(keyBytes) + + publicKey := &rsa.PublicKey{ + N: n, + E: 65537, + } + + return oidc.NewVerifier( + "https://accounts.google.com", + &oidc.StaticKeySet{ + PublicKeys: []crypto.PublicKey{publicKey}, + }, + config, + ) +} + +func azureIDTokenVerifier(ctx context.Context, config *oidc.Config) *oidc.IDTokenVerifier { + keyBytes, err := base64.RawURLEncoding.DecodeString("1djHqyNclRpJWtHCnkP5QWvDxozCTG_ZDnkEmudpcxjnYrVL4RVIwdNCBLAStg8Dob5OUyAlHcRFMCqGTW4HA6kHgIxyfiFsYCBDMHWd2-61N1cAS6S9SdXlWXkBQgU0Qj6q_yFYTRS7J-zI_jMLRQAlpowfDFM1vSTBIci7kqynV6pPOz4jMaDQevmSscEs-jz7e8YXAiiVpN588oBQ0jzQaTTx90WjgRP23mn8mPyabj8gcR3gLwKLsBUhlp1oZj7FopGp8z8LHuueJB_q_LOUa_gAozZ0lfoJxFimXgpgEK7GNVdMRsMH3mIl0A5oYN8f29RFwbG0rNO5ZQ1YWQ") + if err != nil { + panic(err) + } + + n := big.NewInt(0) + n.SetBytes(keyBytes) + + publicKey := &rsa.PublicKey{ + N: n, + E: 65537, + } + + return oidc.NewVerifier( + IssuerAzureMicrosoft, + &oidc.StaticKeySet{ + PublicKeys: []crypto.PublicKey{publicKey}, + }, + config, + ) +} + +var realIDTokens map[string]realIDToken = map[string]realIDToken{ + IssuerGoogle: { + AccessToken: "ya29.a0AWY7CklOn4TehiT4kA6osNP6e-pHErOY8X53T2oUe7Oqqwc3-uIJpoEgoZCUogewBuNWr-JFT2FK9s0E0oRSFtAfu0-uIDckBj5ca1pxnk0-zPkPZouqoIyl0AlIpQjIUEuyuQTYUay99kRajbHcFCR1VMbNcQaCgYKAQESARESFQG1tDrp1joUHupV5Rn8-nWDpKkmMw0165", + IDToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6Ijg1YmE5MzEzZmQ3YTdkNGFmYTg0ODg0YWJjYzg0MDMwMDQzNjMxODAiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI5MTQ2NjY0MjA3NS03OWNwaWs4aWNxYzU4NjY5bjdtaXY5NjZsYmFwOTNhMi5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbSIsImF1ZCI6IjkxNDY2NjQyMDc1LTc5Y3BpazhpY3FjNTg2NjluN21pdjk2NmxiYXA5M2EyLmFwcHMuZ29vZ2xldXNlcmNvbnRlbnQuY29tIiwic3ViIjoiMTAzNzgzMTkwMTI2NDM5NzUxMjY5IiwiaGQiOiJzdXBhYmFzZS5pbyIsImVtYWlsIjoic3RvamFuQHN1cGFiYXNlLmlvIiwiZW1haWxfdmVyaWZpZWQiOnRydWUsImF0X2hhc2giOiJlcGVWV244VmxWa28zd195Unk3UDZRIiwibmFtZSI6IlN0b2phbiBEaW1pdHJvdnNraSIsInBpY3R1cmUiOiJodHRwczovL2xoMy5nb29nbGV1c2VyY29udGVudC5jb20vYS9BQWNIVHRka0dhWjVlcGtqT1dxSEF1UUV4N2cwRlBCeXJiQ2ZNUjVNTk5kYz1zOTYtYyIsImdpdmVuX25hbWUiOiJTdG9qYW4iLCJmYW1pbHlfbmFtZSI6IkRpbWl0cm92c2tpIiwibG9jYWxlIjoiZW4tR0IiLCJpYXQiOjE2ODY2NTk5MzIsImV4cCI6MTY4NjY2MzUzMn0.nKAN9BFSxvavXYfWX4fZHREYY_3O4uOFRFq1KU1NNrBOMq_CPpM8c8PV7ZhKQvGCjBthSjtxGWbcqT0ByA7RdpNW6kj5UpFxEPdhenZ-eO1FwiEVIC8uZpiX6J3Nr7fAqi1P0DVeB3Zr_GrtkS9MDhZNb3hE5NDkvjCulwP4gRBC-5Pn_aRJRESxYkr_naKiSSmVilkmNVjZO4orq6KuYlvWHKHZIRiUI1akt0gVr5GxsEpd_duzUU30yVSPiq8l6fgxvJn2hT0MHa77wo3hvlP0NyAoSE7Nh4tRSowB0Qq7_byDMUmNWfXh-Qqa2M6ywuJ-_3LTLNUJH-cwdm2tNQ", + Time: 
time.Unix(1686659933, 0), // 1 sec after iat + Verifier: googleIDTokenVerifier, + }, + IssuerAzureMicrosoft: { + AccessToken: "access-token", + Time: time.Unix(1697277774, 0), // 1 sec after iat + IDToken: "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6IlhvdVhMWVExVGlwNW9kWWFqaUN0RlZnVmFFcyJ9.eyJ2ZXIiOiIyLjAiLCJpc3MiOiJodHRwczovL2xvZ2luLm1pY3Jvc29mdG9ubGluZS5jb20vOTE4ODA0MGQtNmM2Ny00YzViLWIxMTItMzZhMzA0YjY2ZGFkL3YyLjAiLCJzdWIiOiJBQUFBQUFBQUFBQUFBQUFBQUFBQUFCWkRuRDkxOTBfc2wxcTZwenZlRHZNIiwiYXVkIjoiYTBkOGY5NzItNTRhYy00YWJmLTkxNGMtNTIyMDE0YzQwMjJhIiwiZXhwIjoxNjk3MzY0NDczLCJpYXQiOjE2OTcyNzc3NzMsIm5iZiI6MTY5NzI3Nzc3MywiZW1haWwiOiJzZGltaXRyb3Zza2lAZ21haWwuY29tIiwidGlkIjoiOTE4ODA0MGQtNmM2Ny00YzViLWIxMTItMzZhMzA0YjY2ZGFkIiwieG1zX2Vkb3YiOiIxIiwiYWlvIjoiRHBQV3lZSnRJcUl5OHpyVjROIUlIdGtFa09BMDhPS29lZ1RkYmZQUEVPYmxtYk9ESFQ0cGJVcVI1cExraENyWWZ6bUgzb3A1RzN5RGp2M0tNZ0Rad29lQ1FjKmVueldyb21iQ3BuKkR6OEpQOGMxU3pEVG1TbGp4U3U3UnVLTXNZSjRvS1lDazFBSVcqUUNUTmlMWkpUKlN3WWZQcjZBTW9IejFEZ3pBZEFkbk9uWiFHNUNFeEtQalBxcHRuVmpUZlEkJCJ9.CskICxOaeqd4SkiPdWEHJKZVdhAdgzM5SN7K7FYi0dguQH1-v6XTetDIoEsBn0GZoozXjbG2GgkFcVhhBvNA0ZrDIr4KcjfnJ5-7rwX3AtxdQ3umrHRlGu3jlmbDOtWzPWNMLLRXfR1Mm3pHEUvlzqmk3Ffh4TuAmXID-fb-Xmfuuv1k0UsZ5mlr_3ybTPVZk-Lj0bqkR1L5Zzt4HjgfpchRryJ3Y24b4dDsSjg7mgE_5JivgjhtVef5OnqYhKUF1DTy2pFysFO_eRliK6qjouYeZnQOJnWHP1MgpySAOQ3sVcwvE4P9g7V3QouxByZPv-g99N1K4GwZrtdm46gtTQ", + Verifier: azureIDTokenVerifier, + }, +} + +func TestParseIDToken(t *testing.T) { + defer func() { + OverrideVerifiers = make(map[string]func(context.Context, *oidc.Config) *oidc.IDTokenVerifier) + OverrideClock = nil + }() + + // note that this test can fail if/when the issuers rotate their + // signing keys (which happens rarely if ever) + // then you should obtain new ID tokens and update this test + for issuer, token := range realIDTokens { + oidcProvider, err := oidc.NewProvider(context.Background(), issuer) + require.NoError(t, err) + + OverrideVerifiers[oidcProvider.Endpoint().AuthURL] = token.Verifier + + _, user, err := ParseIDToken(context.Background(), oidcProvider, &oidc.Config{ + SkipClientIDCheck: true, + Now: func() time.Time { + return token.Time + }, + }, token.IDToken, ParseIDTokenOptions{ + AccessToken: token.AccessToken, + }) + require.NoError(t, err) + + require.NotEmpty(t, user.Emails[0].Email) + require.Equal(t, user.Emails[0].Verified, true) + } +} + +func TestAzureIDTokenClaimsIsEmailVerified(t *testing.T) { + positiveExamples := []AzureIDTokenClaims{ + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: nil, + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: true, + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: "1", + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: "true", + }, + } + + negativeExamples := []AzureIDTokenClaims{ + { + Email: "", + XMicrosoftEmailDomainOwnerVerified: true, + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: false, + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: "0", + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: "false", + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: float32(0), + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: float64(0), + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: int(0), + }, + { + Email: "test@example.com", + XMicrosoftEmailDomainOwnerVerified: int32(0), + }, + { + Email: "test@example.com", + 
XMicrosoftEmailDomainOwnerVerified: int64(0), + }, + } + + for i, example := range positiveExamples { + if !example.IsEmailVerified() { + t.Errorf("positive example %v reports negative result", i) + } + } + + for i, example := range negativeExamples { + if example.IsEmailVerified() { + t.Errorf("negative example %v reports positive result", i) + } + } +} diff --git a/auth_v2.169.0/internal/api/provider/provider.go b/auth_v2.169.0/internal/api/provider/provider.go new file mode 100644 index 0000000..857e882 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/provider.go @@ -0,0 +1,128 @@ +package provider + +import ( + "bytes" + "context" + "encoding/json" + "io" + "log" + "net/http" + "os" + "time" + + "github.com/supabase/auth/internal/utilities" + "golang.org/x/oauth2" +) + +var defaultTimeout time.Duration = time.Second * 10 + +func init() { + timeoutStr := os.Getenv("GOTRUE_INTERNAL_HTTP_TIMEOUT") + if timeoutStr != "" { + if timeout, err := time.ParseDuration(timeoutStr); err != nil { + log.Fatalf("error loading GOTRUE_INTERNAL_HTTP_TIMEOUT: %v", err.Error()) + } else if timeout != 0 { + defaultTimeout = timeout + } + } +} + +type Claims struct { + // Reserved claims + Issuer string `json:"iss,omitempty" structs:"iss,omitempty"` + Subject string `json:"sub,omitempty" structs:"sub,omitempty"` + Aud string `json:"aud,omitempty" structs:"aud,omitempty"` + Iat float64 `json:"iat,omitempty" structs:"iat,omitempty"` + Exp float64 `json:"exp,omitempty" structs:"exp,omitempty"` + + // Default profile claims + Name string `json:"name,omitempty" structs:"name,omitempty"` + FamilyName string `json:"family_name,omitempty" structs:"family_name,omitempty"` + GivenName string `json:"given_name,omitempty" structs:"given_name,omitempty"` + MiddleName string `json:"middle_name,omitempty" structs:"middle_name,omitempty"` + NickName string `json:"nickname,omitempty" structs:"nickname,omitempty"` + PreferredUsername string `json:"preferred_username,omitempty" structs:"preferred_username,omitempty"` + Profile string `json:"profile,omitempty" structs:"profile,omitempty"` + Picture string `json:"picture,omitempty" structs:"picture,omitempty"` + Website string `json:"website,omitempty" structs:"website,omitempty"` + Gender string `json:"gender,omitempty" structs:"gender,omitempty"` + Birthdate string `json:"birthdate,omitempty" structs:"birthdate,omitempty"` + ZoneInfo string `json:"zoneinfo,omitempty" structs:"zoneinfo,omitempty"` + Locale string `json:"locale,omitempty" structs:"locale,omitempty"` + UpdatedAt string `json:"updated_at,omitempty" structs:"updated_at,omitempty"` + Email string `json:"email,omitempty" structs:"email,omitempty"` + EmailVerified bool `json:"email_verified,omitempty" structs:"email_verified"` + Phone string `json:"phone,omitempty" structs:"phone,omitempty"` + PhoneVerified bool `json:"phone_verified,omitempty" structs:"phone_verified"` + + // Custom profile claims that are provider specific + CustomClaims map[string]interface{} `json:"custom_claims,omitempty" structs:"custom_claims,omitempty"` + + // TODO: Deprecate in next major release + FullName string `json:"full_name,omitempty" structs:"full_name,omitempty"` + AvatarURL string `json:"avatar_url,omitempty" structs:"avatar_url,omitempty"` + Slug string `json:"slug,omitempty" structs:"slug,omitempty"` + ProviderId string `json:"provider_id,omitempty" structs:"provider_id,omitempty"` + UserNameKey string `json:"user_name,omitempty" structs:"user_name,omitempty"` +} + +// Email is a struct that provides information on whether 
an email is verified or is the primary email address +type Email struct { + Email string + Verified bool + Primary bool +} + +// UserProvidedData is a struct that contains the user's data returned from the oauth provider +type UserProvidedData struct { + Emails []Email + Metadata *Claims +} + +// Provider is an interface for interacting with external account providers +type Provider interface { + AuthCodeURL(string, ...oauth2.AuthCodeOption) string +} + +// OAuthProvider specifies additional methods needed for providers using OAuth +type OAuthProvider interface { + AuthCodeURL(string, ...oauth2.AuthCodeOption) string + GetUserData(context.Context, *oauth2.Token) (*UserProvidedData, error) + GetOAuthToken(string) (*oauth2.Token, error) +} + +func chooseHost(base, defaultHost string) string { + if base == "" { + return "https://" + defaultHost + } + + baseLen := len(base) + if base[baseLen-1] == '/' { + return base[:baseLen-1] + } + + return base +} + +func makeRequest(ctx context.Context, tok *oauth2.Token, g *oauth2.Config, url string, dst interface{}) error { + client := g.Client(ctx, tok) + client.Timeout = defaultTimeout + res, err := client.Get(url) + if err != nil { + return err + } + defer utilities.SafeClose(res.Body) + + bodyBytes, _ := io.ReadAll(res.Body) + res.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices { + return httpError(res.StatusCode, string(bodyBytes)) + } + + if err := json.NewDecoder(res.Body).Decode(dst); err != nil { + return err + } + + return nil +} diff --git a/auth_v2.169.0/internal/api/provider/slack.go b/auth_v2.169.0/internal/api/provider/slack.go new file mode 100644 index 0000000..40377b0 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/slack.go @@ -0,0 +1,94 @@ +package provider + +import ( + "context" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const defaultSlackApiBase = "slack.com" + +type slackProvider struct { + *oauth2.Config + APIPath string +} + +type slackUser struct { + ID string `json:"https://slack.com/user_id"` + Email string `json:"email"` + Name string `json:"name"` + AvatarURL string `json:"picture"` + TeamID string `json:"https://slack.com/team_id"` +} + +// NewSlackProvider creates a Slack account provider with Legacy Slack OAuth. +func NewSlackProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + apiPath := chooseHost(ext.URL, defaultSlackApiBase) + "/api" + authPath := chooseHost(ext.URL, defaultSlackApiBase) + "/oauth" + + oauthScopes := []string{ + "profile", + "email", + "openid", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) 
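
Editor's example: the chooseHost helper above either falls back to an https:// URL built from the provider's default host or strips a trailing slash from a configured override. A minimal stand-alone sketch of that behaviour (the helper is copied here only so the snippet runs on its own; the hosts are illustrative):

package main

import "fmt"

// Mirror of the chooseHost helper above, copied so this sketch is self-contained.
func chooseHost(base, defaultHost string) string {
	if base == "" {
		return "https://" + defaultHost
	}
	if base[len(base)-1] == '/' {
		return base[:len(base)-1]
	}
	return base
}

func main() {
	fmt.Println(chooseHost("", "slack.com"))                            // https://slack.com
	fmt.Println(chooseHost("https://slack.example.corp/", "slack.com")) // https://slack.example.corp
	fmt.Println(chooseHost("https://slack.example.corp", "slack.com"))  // https://slack.example.corp
}
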
+ } + + return &slackProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authPath + "/authorize", + TokenURL: apiPath + "/oauth.access", + }, + Scopes: oauthScopes, + RedirectURL: ext.RedirectURI, + }, + APIPath: apiPath, + }, nil +} + +func (g slackProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g slackProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u slackUser + if err := makeRequest(ctx, tok, g.Config, g.APIPath+"/openid.connect.userInfo", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + Verified: true, // Slack doesn't provide data on if email is verified. + Primary: true, + }} + } + + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.ID, + Name: u.Name, + Picture: u.AvatarURL, + CustomClaims: map[string]interface{}{ + "https://slack.com/team_id": u.TeamID, + }, + + // To be deprecated + AvatarURL: u.AvatarURL, + FullName: u.Name, + ProviderId: u.ID, + } + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/slack_oidc.go b/auth_v2.169.0/internal/api/provider/slack_oidc.go new file mode 100644 index 0000000..3c7a5eb --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/slack_oidc.go @@ -0,0 +1,99 @@ +package provider + +import ( + "context" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const defaultSlackOIDCApiBase = "slack.com" + +type slackOIDCProvider struct { + *oauth2.Config + APIPath string +} + +type slackOIDCUser struct { + ID string `json:"https://slack.com/user_id"` + TeamID string `json:"https://slack.com/team_id"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Name string `json:"name"` + AvatarURL string `json:"picture"` +} + +// NewSlackOIDCProvider creates a Slack account provider with Sign in with Slack. +func NewSlackOIDCProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + apiPath := chooseHost(ext.URL, defaultSlackOIDCApiBase) + "/api" + authPath := chooseHost(ext.URL, defaultSlackOIDCApiBase) + "/openid" + + // these are required scopes for slack's OIDC flow + // see https://api.slack.com/authentication/sign-in-with-slack#implementation + oauthScopes := []string{ + "profile", + "email", + "openid", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) 
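
Editor's example: for both Slack variants the extra scopes argument is a comma-separated string appended to the default profile/email/openid scopes. A hedged usage sketch, using only the configuration fields that the constructors above actually read (ClientID, Secret, RedirectURI) plus an Enabled flag that is assumed and not shown in this diff; all values are placeholders:

package main

import (
	"fmt"

	"github.com/supabase/auth/internal/api/provider"
	"github.com/supabase/auth/internal/conf"
)

func main() {
	cfg := conf.OAuthProviderConfiguration{
		Enabled:     true,                    // assumed field; not shown in this diff
		ClientID:    []string{"slack-client-id"},
		Secret:      "slack-client-secret",
		RedirectURI: "https://auth.example.com/callback",
	}

	// "groups:read" is purely illustrative; it is appended after the required scopes.
	p, err := provider.NewSlackOIDCProvider(cfg, "groups:read")
	if err != nil {
		panic(err)
	}

	fmt.Println(p.AuthCodeURL("random-state"))
}
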
+ } + + return &slackOIDCProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authPath + "/connect/authorize", + TokenURL: apiPath + "/openid.connect.token", + }, + Scopes: oauthScopes, + RedirectURL: ext.RedirectURI, + }, + APIPath: apiPath, + }, nil +} + +func (g slackOIDCProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g slackOIDCProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u slackOIDCUser + if err := makeRequest(ctx, tok, g.Config, g.APIPath+"/openid.connect.userInfo", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + // email_verified is returned as part of the response + // see: https://api.slack.com/authentication/sign-in-with-slack#response + Verified: u.EmailVerified, + Primary: true, + }} + } + + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.ID, + Name: u.Name, + Picture: u.AvatarURL, + CustomClaims: map[string]interface{}{ + "https://slack.com/team_id": u.TeamID, + }, + + // To be deprecated + AvatarURL: u.AvatarURL, + FullName: u.Name, + ProviderId: u.ID, + } + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/spotify.go b/auth_v2.169.0/internal/api/provider/spotify.go new file mode 100644 index 0000000..e6d2f38 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/spotify.go @@ -0,0 +1,114 @@ +package provider + +import ( + "context" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultSpotifyAPIBase = "api.spotify.com/v1" // Used to get user data + defaultSpotifyAuthBase = "accounts.spotify.com" // Used for OAuth flow +) + +type spotifyProvider struct { + *oauth2.Config + APIPath string +} + +type spotifyUser struct { + DisplayName string `json:"display_name"` + Avatars []spotifyUserImage `json:"images"` + Email string `json:"email"` + ID string `json:"id"` +} + +type spotifyUserImage struct { + Url string `json:"url"` + Height int `json:"height"` + Width int `json:"width"` +} + +// NewSpotifyProvider creates a Spotify account provider. +func NewSpotifyProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + apiPath := chooseHost(ext.URL, defaultSpotifyAPIBase) + authPath := chooseHost(ext.URL, defaultSpotifyAuthBase) + + oauthScopes := []string{ + "user-read-email", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + return &spotifyProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authPath + "/authorize", + TokenURL: authPath + "/api/token", + }, + Scopes: oauthScopes, + RedirectURL: ext.RedirectURI, + }, + APIPath: apiPath, + }, nil +} + +func (g spotifyProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g spotifyProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u spotifyUser + if err := makeRequest(ctx, tok, g.Config, g.APIPath+"/me", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + // Spotify dosen't provide data on whether the user's email is verified. 
+ // https://developer.spotify.com/documentation/web-api/reference/get-current-users-profile + Verified: false, + Primary: true, + }} + } + + var avatarURL string + + // Spotify returns a list of avatars, we want to use the largest one + if len(u.Avatars) >= 1 { + largestAvatar := u.Avatars[0] + + for _, avatar := range u.Avatars { + if avatar.Height*avatar.Width > largestAvatar.Height*largestAvatar.Width { + largestAvatar = avatar + } + } + + avatarURL = largestAvatar.Url + } + + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.ID, + Name: u.DisplayName, + Picture: avatarURL, + + // To be deprecated + AvatarURL: avatarURL, + FullName: u.DisplayName, + ProviderId: u.ID, + } + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/twitch.go b/auth_v2.169.0/internal/api/provider/twitch.go new file mode 100644 index 0000000..defb198 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/twitch.go @@ -0,0 +1,154 @@ +package provider + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" + "golang.org/x/oauth2" +) + +// Twitch + +const ( + defaultTwitchAuthBase = "id.twitch.tv" + defaultTwitchAPIBase = "api.twitch.tv" +) + +type twitchProvider struct { + *oauth2.Config + APIHost string +} + +type twitchUsers struct { + Data []struct { + ID string `json:"id"` + Login string `json:"login"` + DisplayName string `json:"display_name"` + Type string `json:"type"` + BroadcasterType string `json:"broadcaster_type"` + Description string `json:"description"` + ProfileImageURL string `json:"profile_image_url"` + OfflineImageURL string `json:"offline_image_url"` + ViewCount int `json:"view_count"` + Email string `json:"email"` + CreatedAt time.Time `json:"created_at"` + } `json:"data"` +} + +// NewTwitchProvider creates a Twitch account provider. +func NewTwitchProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + apiHost := chooseHost(ext.URL, defaultTwitchAPIBase) + authHost := chooseHost(ext.URL, defaultTwitchAuthBase) + + oauthScopes := []string{ + "user:read:email", + } + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) 
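
Editor's example: the Spotify handler above keeps the avatar with the largest pixel area instead of trusting the order of the images array. The same selection logic as a stand-alone sketch with made-up image sizes:

package main

import "fmt"

type image struct {
	URL           string
	Height, Width int
}

func main() {
	avatars := []image{
		{URL: "https://example.com/64.png", Height: 64, Width: 64},
		{URL: "https://example.com/300.png", Height: 300, Width: 300},
		{URL: "https://example.com/128.png", Height: 128, Width: 128},
	}

	// Pick the avatar with the largest Height*Width, as in GetUserData above.
	largest := avatars[0]
	for _, a := range avatars {
		if a.Height*a.Width > largest.Height*largest.Width {
			largest = a
		}
	}

	fmt.Println(largest.URL) // https://example.com/300.png
}
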
+ } + + return &twitchProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authHost + "/oauth2/authorize", + TokenURL: authHost + "/oauth2/token", + }, + RedirectURL: ext.RedirectURI, + Scopes: oauthScopes, + }, + APIHost: apiHost, + }, nil +} + +func (t twitchProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return t.Exchange(context.Background(), code) +} + +func (t twitchProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u twitchUsers + + // Perform http request, because we neeed to set the Client-Id header + req, err := http.NewRequest("GET", t.APIHost+"/helix/users", nil) + + if err != nil { + return nil, err + } + + // set headers + req.Header.Set("Client-Id", t.Config.ClientID) + req.Header.Set("Authorization", "Bearer "+tok.AccessToken) + + client := &http.Client{Timeout: defaultTimeout} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer utilities.SafeClose(resp.Body) + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("a %v error occurred with retrieving user from twitch", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + err = json.Unmarshal(body, &u) + if err != nil { + return nil, err + } + + if len(u.Data) == 0 { + return nil, errors.New("unable to find user with twitch provider") + } + + user := u.Data[0] + + data := &UserProvidedData{} + if user.Email != "" { + data.Emails = []Email{{ + Email: user.Email, + Verified: true, + Primary: true, + }} + } + + data.Metadata = &Claims{ + Issuer: t.APIHost, + Subject: user.ID, + Picture: user.ProfileImageURL, + Name: user.Login, + NickName: user.DisplayName, + CustomClaims: map[string]interface{}{ + "broadcaster_type": user.BroadcasterType, + "description": user.Description, + "type": user.Type, + "offline_image_url": user.OfflineImageURL, + "view_count": user.ViewCount, + }, + + // To be deprecated + Slug: user.DisplayName, + AvatarURL: user.ProfileImageURL, + FullName: user.Login, + ProviderId: user.ID, + } + + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/twitter.go b/auth_v2.169.0/internal/api/provider/twitter.go new file mode 100644 index 0000000..8dc5a4c --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/twitter.go @@ -0,0 +1,155 @@ +package provider + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/mrjones/oauth" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" + "golang.org/x/oauth2" +) + +const ( + defaultTwitterAPIBase = "api.twitter.com" + requestURL = "/oauth/request_token" + authenticateURL = "/oauth/authenticate" + tokenURL = "/oauth/access_token" //#nosec G101 -- Not a secret value. + endpointProfile = "/1.1/account/verify_credentials.json" +) + +// TwitterProvider stores the custom config for twitter provider +type TwitterProvider struct { + ClientKey string + Secret string + CallbackURL string + AuthURL string + RequestToken *oauth.RequestToken + OauthVerifier string + Consumer *oauth.Consumer + UserInfoURL string +} + +type twitterUser struct { + UserName string `json:"screen_name"` + Name string `json:"name"` + AvatarURL string `json:"profile_image_url_https"` + Email string `json:"email"` + ID string `json:"id_str"` +} + +// NewTwitterProvider creates a Twitter account provider. 
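
Editor's example: Twitch's Helix endpoint needs the Client-Id header in addition to the bearer token, which is why the handler above builds the request by hand instead of going through makeRequest. A hedged sketch of the same call using NewRequestWithContext so the request also honours context cancellation (endpoint and header names as in the code above; the credential values are placeholders):

package main

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"time"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.twitch.tv/helix/users", nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Client-Id", "twitch-client-id")                   // placeholder
	req.Header.Set("Authorization", "Bearer twitch-access-token")     // placeholder

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.StatusCode, len(body))
}
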
+func NewTwitterProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + authHost := chooseHost(ext.URL, defaultTwitterAPIBase) + p := &TwitterProvider{ + ClientKey: ext.ClientID[0], + Secret: ext.Secret, + CallbackURL: ext.RedirectURI, + UserInfoURL: authHost + endpointProfile, + } + p.Consumer = newConsumer(p, authHost) + return p, nil +} + +// GetOAuthToken is a stub method for OAuthProvider interface, unused in OAuth1.0 protocol +func (t TwitterProvider) GetOAuthToken(_ string) (*oauth2.Token, error) { + return &oauth2.Token{}, nil +} + +// GetUserData is a stub method for OAuthProvider interface, unused in OAuth1.0 protocol +func (t TwitterProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + return &UserProvidedData{}, nil +} + +// FetchUserData retrieves the user's data from the twitter provider +func (t TwitterProvider) FetchUserData(ctx context.Context, tok *oauth.AccessToken) (*UserProvidedData, error) { + var u twitterUser + resp, err := t.Consumer.Get( + t.UserInfoURL, + map[string]string{"include_entities": "false", "skip_status": "true", "include_email": "true"}, + tok, + ) + if err != nil { + return nil, err + } + defer utilities.SafeClose(resp.Body) + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return &UserProvidedData{}, fmt.Errorf("a %v error occurred with retrieving user from twitter", resp.StatusCode) + } + bits, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + _ = json.NewDecoder(bytes.NewReader(bits)).Decode(&u) + + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + Verified: true, + Primary: true, + }} + } + + data.Metadata = &Claims{ + Issuer: t.UserInfoURL, + Subject: u.ID, + Name: u.Name, + Picture: u.AvatarURL, + PreferredUsername: u.UserName, + + // To be deprecated + UserNameKey: u.UserName, + FullName: u.Name, + AvatarURL: u.AvatarURL, + ProviderId: u.ID, + } + + return data, nil +} + +// AuthCodeURL fetches the request token from the twitter provider +func (t *TwitterProvider) AuthCodeURL(state string, args ...oauth2.AuthCodeOption) string { + // we do nothing with the state here as the state is passed in the requestURL step + requestToken, url, err := t.Consumer.GetRequestTokenAndUrl(t.CallbackURL + "?state=" + state) + if err != nil { + return "" + } + t.RequestToken = requestToken + t.AuthURL = url + return t.AuthURL +} + +func newConsumer(provider *TwitterProvider, authHost string) *oauth.Consumer { + c := oauth.NewConsumer( + provider.ClientKey, + provider.Secret, + oauth.ServiceProvider{ + RequestTokenUrl: authHost + requestURL, + AuthorizeTokenUrl: authHost + authenticateURL, + AccessTokenUrl: authHost + tokenURL, + }, + ) + return c +} + +// Marshal encodes the twitter request token +func (t TwitterProvider) Marshal() string { + b, _ := json.Marshal(t.RequestToken) + return string(b) +} + +// Unmarshal decodes the twitter request token +func (t TwitterProvider) Unmarshal(data string) (*oauth.RequestToken, error) { + requestToken := &oauth.RequestToken{} + err := json.NewDecoder(strings.NewReader(data)).Decode(requestToken) + return requestToken, err +} diff --git a/auth_v2.169.0/internal/api/provider/vercel_marketplace.go b/auth_v2.169.0/internal/api/provider/vercel_marketplace.go new file mode 100644 index 0000000..ba76a74 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/vercel_marketplace.go @@ -0,0 
+1,78 @@ +package provider + +import ( + "context" + "errors" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultVercelMarketplaceAPIBase = "api.vercel.com" + IssuerVercelMarketplace = "https://marketplace.vercel.com" +) + +type vercelMarketplaceProvider struct { + *oauth2.Config + oidc *oidc.Provider + APIPath string +} + +// NewVercelMarketplaceProvider creates a VercelMarketplace account provider via OIDC. +func NewVercelMarketplaceProvider(ext conf.OAuthProviderConfiguration, scopes string) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + apiPath := chooseHost(ext.URL, defaultVercelMarketplaceAPIBase) + + oauthScopes := []string{} + + if scopes != "" { + oauthScopes = append(oauthScopes, strings.Split(scopes, ",")...) + } + + oidcProvider, err := oidc.NewProvider(context.Background(), IssuerVercelMarketplace) + if err != nil { + return nil, err + } + + return &vercelMarketplaceProvider{ + oidc: oidcProvider, + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: apiPath + "/oauth/v2/authorization", + TokenURL: apiPath + "/oauth/v2/accessToken", + }, + Scopes: oauthScopes, + RedirectURL: ext.RedirectURI, + }, + APIPath: apiPath, + }, nil +} + +func (g vercelMarketplaceProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g vercelMarketplaceProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + idToken := tok.Extra("id_token") + if tok.AccessToken == "" || idToken == nil { + return nil, errors.New("vercel_marketplace: no OIDC ID token present in response") + } + + _, data, err := ParseIDToken(ctx, g.oidc, &oidc.Config{ + ClientID: g.ClientID, + }, idToken.(string), ParseIDTokenOptions{ + AccessToken: tok.AccessToken, + }) + if err != nil { + return nil, err + } + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/workos.go b/auth_v2.169.0/internal/api/provider/workos.go new file mode 100644 index 0000000..75cafa2 --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/workos.go @@ -0,0 +1,98 @@ +package provider + +import ( + "context" + "strings" + + "github.com/mitchellh/mapstructure" + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultWorkOSAPIBase = "api.workos.com" +) + +type workosProvider struct { + *oauth2.Config + APIPath string +} + +// See https://workos.com/docs/reference/sso/profile. +type workosUser struct { + ID string `mapstructure:"id"` + ConnectionID string `mapstructure:"connection_id"` + OrganizationID string `mapstructure:"organization_id"` + ConnectionType string `mapstructure:"connection_type"` + Email string `mapstructure:"email"` + FirstName string `mapstructure:"first_name"` + LastName string `mapstructure:"last_name"` + Object string `mapstructure:"object"` + IdpID string `mapstructure:"idp_id"` + RawAttributes map[string]interface{} `mapstructure:"raw_attributes"` +} + +// NewWorkOSProvider creates a WorkOS account provider. 
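
Editor's example: the Vercel Marketplace handler above relies on the OIDC convention of returning the ID token among the token response's extra fields. A minimal sketch of that extraction step in isolation (the token values are fabricated for illustration):

package main

import (
	"fmt"

	"golang.org/x/oauth2"
)

func main() {
	// WithExtra attaches provider-specific fields, mimicking what a token
	// endpoint returns alongside the access token; the id_token is a placeholder.
	tok := (&oauth2.Token{AccessToken: "access-token"}).WithExtra(map[string]interface{}{
		"id_token": "eyJ...placeholder...",
	})

	idToken, ok := tok.Extra("id_token").(string)
	if !ok || idToken == "" {
		panic("no OIDC ID token present in response")
	}

	fmt.Println("id_token present, length:", len(idToken))
}
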
+func NewWorkOSProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + apiPath := chooseHost(ext.URL, defaultWorkOSAPIBase) + + return &workosProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: apiPath + "/sso/authorize", + TokenURL: apiPath + "/sso/token", + }, + RedirectURL: ext.RedirectURI, + }, + APIPath: apiPath, + }, nil +} + +func (g workosProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g workosProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + if tok.AccessToken == "" { + return &UserProvidedData{}, nil + } + + // WorkOS API returns the user's profile data along with the OAuth2 token, so + // we can just convert from `map[string]interface{}` to `workosUser` without + // an additional network request. + var u workosUser + err := mapstructure.Decode(tok.Extra("profile"), &u) + if err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + data.Emails = []Email{{ + Email: u.Email, + Verified: true, + Primary: true, + }} + } + + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.ID, + Name: strings.TrimSpace(u.FirstName + " " + u.LastName), + CustomClaims: map[string]interface{}{ + "connection_id": u.ConnectionID, + "organization_id": u.OrganizationID, + }, + + // To be deprecated + FullName: strings.TrimSpace(u.FirstName + " " + u.LastName), + ProviderId: u.ID, + } + + return data, nil +} diff --git a/auth_v2.169.0/internal/api/provider/zoom.go b/auth_v2.169.0/internal/api/provider/zoom.go new file mode 100644 index 0000000..8e2e9fa --- /dev/null +++ b/auth_v2.169.0/internal/api/provider/zoom.go @@ -0,0 +1,91 @@ +package provider + +import ( + "context" + "strings" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/oauth2" +) + +const ( + defaultZoomAuthBase = "zoom.us" + defaultZoomAPIBase = "api.zoom.us" +) + +type zoomProvider struct { + *oauth2.Config + APIPath string +} + +type zoomUser struct { + ID string `json:"id"` + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Email string `json:"email"` + EmailVerified int `json:"verified"` + LoginType string `json:"login_type"` + AvatarURL string `json:"pic_url"` +} + +// NewZoomProvider creates a Zoom account provider. 
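
Editor's example: WorkOS returns the profile inside the token response itself, so the handler above only has to decode a map[string]interface{} into the typed struct without another network call. A small self-contained sketch of that mapstructure step (struct fields and values are illustrative):

package main

import (
	"fmt"

	"github.com/mitchellh/mapstructure"
)

type profile struct {
	ID        string `mapstructure:"id"`
	Email     string `mapstructure:"email"`
	FirstName string `mapstructure:"first_name"`
	LastName  string `mapstructure:"last_name"`
}

func main() {
	// Stand-in for tok.Extra("profile"), which arrives as a generic map.
	raw := map[string]interface{}{
		"id":         "prof_123",
		"email":      "jane@example.com",
		"first_name": "Jane",
		"last_name":  "Doe",
	}

	var p profile
	if err := mapstructure.Decode(raw, &p); err != nil {
		panic(err)
	}

	fmt.Printf("%s %s <%s>\n", p.FirstName, p.LastName, p.Email)
}
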
+func NewZoomProvider(ext conf.OAuthProviderConfiguration) (OAuthProvider, error) { + if err := ext.ValidateOAuth(); err != nil { + return nil, err + } + + apiPath := chooseHost(ext.URL, defaultZoomAPIBase) + "/v2" + authPath := chooseHost(ext.URL, defaultZoomAuthBase) + "/oauth" + + return &zoomProvider{ + Config: &oauth2.Config{ + ClientID: ext.ClientID[0], + ClientSecret: ext.Secret, + Endpoint: oauth2.Endpoint{ + AuthURL: authPath + "/authorize", + TokenURL: authPath + "/token", + }, + RedirectURL: ext.RedirectURI, + }, + APIPath: apiPath, + }, nil +} + +func (g zoomProvider) GetOAuthToken(code string) (*oauth2.Token, error) { + return g.Exchange(context.Background(), code) +} + +func (g zoomProvider) GetUserData(ctx context.Context, tok *oauth2.Token) (*UserProvidedData, error) { + var u zoomUser + if err := makeRequest(ctx, tok, g.Config, g.APIPath+"/users/me", &u); err != nil { + return nil, err + } + + data := &UserProvidedData{} + if u.Email != "" { + email := Email{} + email.Email = u.Email + email.Primary = true + // A login_type of "100" refers to email-based logins, not oauth. + // A user is verified (type 1) only if they received an email when their profile was created and confirmed the link. + // A zoom user will only be sent an email confirmation link if they signed up using their zoom work email and not oauth. + // See: https://devforum.zoom.us/t/how-to-determine-if-a-zoom-user-actually-owns-their-email-address/44430 + if u.LoginType != "100" || u.EmailVerified != 0 { + email.Verified = true + } + data.Emails = []Email{email} + } + + data.Metadata = &Claims{ + Issuer: g.APIPath, + Subject: u.ID, + Name: strings.TrimSpace(u.FirstName + " " + u.LastName), + Picture: u.AvatarURL, + + // To be deprecated + AvatarURL: u.AvatarURL, + FullName: strings.TrimSpace(u.FirstName + " " + u.LastName), + ProviderId: u.ID, + } + return data, nil +} diff --git a/auth_v2.169.0/internal/api/reauthenticate.go b/auth_v2.169.0/internal/api/reauthenticate.go new file mode 100644 index 0000000..5146ae4 --- /dev/null +++ b/auth_v2.169.0/internal/api/reauthenticate.go @@ -0,0 +1,97 @@ +package api + +import ( + "net/http" + + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +const InvalidNonceMessage = "Nonce has expired or is invalid" + +// Reauthenticate sends a reauthentication otp to either the user's email or phone +func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + user := getUser(ctx) + email, phone := user.GetEmail(), user.GetPhone() + + if email == "" && phone == "" { + return badRequestError(ErrorCodeValidationFailed, "Reauthentication requires the user to have an email or a phone number") + } + + if email != "" { + if !user.IsConfirmed() { + return unprocessableEntityError(ErrorCodeEmailNotConfirmed, "Please verify your email first.") + } + } else if phone != "" { + if !user.IsPhoneConfirmed() { + return unprocessableEntityError(ErrorCodePhoneNotConfirmed, "Please verify your phone first.") + } + } + + messageID := "" + err := db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserReauthenticateAction, "", nil); terr != nil { + return terr + } + if email != "" { + return a.sendReauthenticationOtp(r, tx, user) + } else if phone != "" { + mID, err := 
a.sendPhoneConfirmation(r, tx, user, phone, phoneReauthenticationOtp, sms_provider.SMSProvider) + if err != nil { + return err + } + + messageID = mID + } + return nil + }) + if err != nil { + return err + } + + ret := map[string]any{} + if messageID != "" { + ret["message_id"] = messageID + + } + + return sendJSON(w, http.StatusOK, ret) +} + +// verifyReauthentication checks if the nonce provided is valid +func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, config *conf.GlobalConfiguration, user *models.User) error { + if user.ReauthenticationToken == "" || user.ReauthenticationSentAt == nil { + return unprocessableEntityError(ErrorCodeReauthenticationNotValid, InvalidNonceMessage) + } + var isValid bool + if user.GetEmail() != "" { + tokenHash := crypto.GenerateTokenHash(user.GetEmail(), nonce) + isValid = isOtpValid(tokenHash, user.ReauthenticationToken, user.ReauthenticationSentAt, config.Mailer.OtpExp) + } else if user.GetPhone() != "" { + if config.Sms.IsTwilioVerifyProvider() { + smsProvider, _ := sms_provider.GetSmsProvider(*config) + if err := smsProvider.(*sms_provider.TwilioVerifyProvider).VerifyOTP(string(user.Phone), nonce); err != nil { + return forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) + } + return nil + } else { + tokenHash := crypto.GenerateTokenHash(user.GetPhone(), nonce) + isValid = isOtpValid(tokenHash, user.ReauthenticationToken, user.ReauthenticationSentAt, config.Sms.OtpExp) + } + } else { + return unprocessableEntityError(ErrorCodeReauthenticationNotValid, "Reauthentication requires an email or a phone number") + } + if !isValid { + return unprocessableEntityError(ErrorCodeReauthenticationNotValid, InvalidNonceMessage) + } + if err := user.ConfirmReauthentication(tx); err != nil { + return internalServerError("Error during reauthentication").WithInternalError(err) + } + return nil +} diff --git a/auth_v2.169.0/internal/api/recover.go b/auth_v2.169.0/internal/api/recover.go new file mode 100644 index 0000000..7c03c32 --- /dev/null +++ b/auth_v2.169.0/internal/api/recover.go @@ -0,0 +1,73 @@ +package api + +import ( + "net/http" + + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// RecoverParams holds the parameters for a password recovery request +type RecoverParams struct { + Email string `json:"email"` + CodeChallenge string `json:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method"` +} + +func (p *RecoverParams) Validate(a *API) error { + if p.Email == "" { + return badRequestError(ErrorCodeValidationFailed, "Password recovery requires an email") + } + var err error + if p.Email, err = a.validateEmail(p.Email); err != nil { + return err + } + if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { + return err + } + return nil +} + +// Recover sends a recovery email +func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + params := &RecoverParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + flowType := getFlowFromChallenge(params.CodeChallenge) + if err := params.Validate(a); err != nil { + return err + } + + var user *models.User + var err error + aud := a.requestAud(ctx, r) + + user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) + if err != nil { + if models.IsNotFoundError(err) { + return sendJSON(w, http.StatusOK, map[string]string{}) + } + return 
internalServerError("Unable to process request").WithInternalError(err) + } + if isPKCEFlow(flowType) { + if _, err := generateFlowState(db, models.Recovery.String(), models.Recovery, params.CodeChallengeMethod, params.CodeChallenge, &(user.ID)); err != nil { + return err + } + } + + err = db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil { + return terr + } + return a.sendPasswordRecovery(r, tx, user, flowType) + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, map[string]string{}) +} diff --git a/auth_v2.169.0/internal/api/recover_test.go b/auth_v2.169.0/internal/api/recover_test.go new file mode 100644 index 0000000..a7e655c --- /dev/null +++ b/auth_v2.169.0/internal/api/recover_test.go @@ -0,0 +1,153 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type RecoverTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestRecover(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &RecoverTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *RecoverTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create user + u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") +} + +func (ts *RecoverTestSuite) TestRecover_FirstRecovery() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.RecoverySentAt = &time.Time{} + require.NoError(ts.T(), ts.API.db.Update(u)) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) +} + +func (ts *RecoverTestSuite) TestRecover_NoEmailSent() { + recoveryTime := time.Now().UTC().Add(-59 * time.Second) + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.RecoverySentAt = &recoveryTime + require.NoError(ts.T(), ts.API.db.Update(u)) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusTooManyRequests, 
w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // ensure it did not send a new email + u1 := recoveryTime.Round(time.Second).Unix() + u2 := u.RecoverySentAt.Round(time.Second).Unix() + assert.Equal(ts.T(), u1, u2) +} + +func (ts *RecoverTestSuite) TestRecover_NewEmailSent() { + recoveryTime := time.Now().UTC().Add(-20 * time.Minute) + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.RecoverySentAt = &recoveryTime + require.NoError(ts.T(), ts.API.db.Update(u)) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // ensure it sent a new email + assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) +} + +func (ts *RecoverTestSuite) TestRecover_NoSideChannelLeak() { + email := "doesntexist@example.com" + + _, err := models.FindUserByEmailAndAudience(ts.API.db, email, ts.Config.JWT.Aud) + require.True(ts.T(), models.IsNotFoundError(err), "User with email %s does exist", email) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": email, + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) +} diff --git a/auth_v2.169.0/internal/api/resend.go b/auth_v2.169.0/internal/api/resend.go new file mode 100644 index 0000000..2c30536 --- /dev/null +++ b/auth_v2.169.0/internal/api/resend.go @@ -0,0 +1,154 @@ +package api + +import ( + "net/http" + + "github.com/supabase/auth/internal/api/sms_provider" + mail "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// ResendConfirmationParams holds the parameters for a resend request +type ResendConfirmationParams struct { + Type string `json:"type"` + Email string `json:"email"` + Phone string `json:"phone"` +} + +func (p *ResendConfirmationParams) Validate(a *API) error { + config := a.config + + switch p.Type { + case mail.SignupVerification, mail.EmailChangeVerification, smsVerification, phoneChangeVerification: + break + default: + // type does not match one of the above + return badRequestError(ErrorCodeValidationFailed, "Missing one of these types: signup, email_change, sms, phone_change") + + } + if p.Email == "" && p.Type == mail.SignupVerification { + return badRequestError(ErrorCodeValidationFailed, "Type provided requires an email address") + } + if p.Phone == "" && p.Type == smsVerification { + return badRequestError(ErrorCodeValidationFailed, "Type provided requires a phone number") + } + + var err error + if p.Email != "" && p.Phone != "" { + return badRequestError(ErrorCodeValidationFailed, "Only an email 
address or phone number should be provided.") + } else if p.Email != "" { + if !config.External.Email.Enabled { + return badRequestError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") + } + p.Email, err = a.validateEmail(p.Email) + if err != nil { + return err + } + } else if p.Phone != "" { + if !config.External.Phone.Enabled { + return badRequestError(ErrorCodePhoneProviderDisabled, "Phone logins are disabled") + } + p.Phone, err = validatePhone(p.Phone) + if err != nil { + return err + } + } else { + // both email and phone are empty + return badRequestError(ErrorCodeValidationFailed, "Missing email address or phone number") + } + return nil +} + +// Recover sends a recovery email +func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + params := &ResendConfirmationParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if err := params.Validate(a); err != nil { + return err + } + + var user *models.User + var err error + aud := a.requestAud(ctx, r) + if params.Email != "" { + user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) + } else if params.Phone != "" { + user, err = models.FindUserByPhoneAndAudience(db, params.Phone, aud) + } + + if err != nil { + if models.IsNotFoundError(err) { + return sendJSON(w, http.StatusOK, map[string]string{}) + } + return internalServerError("Unable to process request").WithInternalError(err) + } + + switch params.Type { + case mail.SignupVerification: + if user.IsConfirmed() { + // if the user's email is confirmed already, we don't need to send a confirmation email again + return sendJSON(w, http.StatusOK, map[string]string{}) + } + case smsVerification: + if user.IsPhoneConfirmed() { + // if the user's phone is confirmed already, we don't need to send a confirmation sms again + return sendJSON(w, http.StatusOK, map[string]string{}) + } + case mail.EmailChangeVerification: + // do not resend if user doesn't have a new email address + if user.EmailChange == "" { + return sendJSON(w, http.StatusOK, map[string]string{}) + } + case phoneChangeVerification: + // do not resend if user doesn't have a new phone number + if user.PhoneChange == "" { + return sendJSON(w, http.StatusOK, map[string]string{}) + } + } + + messageID := "" + err = db.Transaction(func(tx *storage.Connection) error { + switch params.Type { + case mail.SignupVerification: + if terr := models.NewAuditLogEntry(r, tx, user, models.UserConfirmationRequestedAction, "", nil); terr != nil { + return terr + } + // PKCE not implemented yet + return a.sendConfirmation(r, tx, user, models.ImplicitFlow) + case smsVerification: + if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil { + return terr + } + mID, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, sms_provider.SMSProvider) + if terr != nil { + return terr + } + messageID = mID + case mail.EmailChangeVerification: + return a.sendEmailChange(r, tx, user, user.EmailChange, models.ImplicitFlow) + case phoneChangeVerification: + mID, terr := a.sendPhoneConfirmation(r, tx, user, user.PhoneChange, phoneChangeVerification, sms_provider.SMSProvider) + if terr != nil { + return terr + } + messageID = mID + } + return nil + }) + if err != nil { + return err + } + + ret := map[string]any{} + if messageID != "" { + ret["message_id"] = messageID + } + + return sendJSON(w, http.StatusOK, ret) +} diff --git 
a/auth_v2.169.0/internal/api/resend_test.go b/auth_v2.169.0/internal/api/resend_test.go new file mode 100644 index 0000000..83c58c4 --- /dev/null +++ b/auth_v2.169.0/internal/api/resend_test.go @@ -0,0 +1,217 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + mail "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/models" +) + +type ResendTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestResend(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &ResendTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *ResendTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) +} + +func (ts *ResendTestSuite) TestResendValidation() { + cases := []struct { + desc string + params map[string]interface{} + expected map[string]interface{} + }{ + { + desc: "Invalid type", + params: map[string]interface{}{ + "type": "invalid", + "email": "foo@example.com", + }, + expected: map[string]interface{}{ + "code": http.StatusBadRequest, + "message": "Missing one of these types: signup, email_change, sms, phone_change", + }, + }, + { + desc: "Type & email mismatch", + params: map[string]interface{}{ + "type": "sms", + "email": "foo@example.com", + }, + expected: map[string]interface{}{ + "code": http.StatusBadRequest, + "message": "Type provided requires a phone number", + }, + }, + { + desc: "Phone & email change type", + params: map[string]interface{}{ + "type": "email_change", + "phone": "+123456789", + }, + expected: map[string]interface{}{ + "code": http.StatusOK, + "message": nil, + }, + }, + { + desc: "Email & phone number provided", + params: map[string]interface{}{ + "type": "email_change", + "phone": "+123456789", + "email": "foo@example.com", + }, + expected: map[string]interface{}{ + "code": http.StatusBadRequest, + "message": "Only an email address or phone number should be provided.", + }, + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) + req := httptest.NewRequest(http.MethodPost, "http://localhost/resend", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expected["code"], w.Code) + + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), c.expected["message"], data["msg"]) + }) + } + +} + +func (ts *ResendTestSuite) TestResendSuccess() { + // Create user + u, err := models.NewUser("123456789", "foo@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + + // Avoid max freq limit error + now := time.Now().Add(-1 * time.Minute) + + // Enable Phone Logoin for phone related tests + ts.Config.External.Phone.Enabled = true + // disable secure email change + ts.Config.Mailer.SecureEmailChangeEnabled = false + + u.ConfirmationToken = "123456" + u.ConfirmationSentAt = &now + u.EmailChange = "bar@example.com" + u.EmailChangeSentAt = &now + u.EmailChangeTokenNew = "123456" + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, 
u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.EmailChange, u.EmailChangeTokenNew, models.EmailChangeTokenNew)) + + phoneUser, err := models.NewUser("1234567890", "", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + phoneUser.EmailChange = "bar@example.com" + phoneUser.EmailChangeSentAt = &now + phoneUser.EmailChangeTokenNew = "123456" + require.NoError(ts.T(), ts.API.db.Create(phoneUser), "Error saving new test user") + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, phoneUser.ID, phoneUser.EmailChange, phoneUser.EmailChangeTokenNew, models.EmailChangeTokenNew)) + + emailUser, err := models.NewUser("", "bar@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + phoneUser.PhoneChange = "1234567890" + phoneUser.PhoneChangeSentAt = &now + phoneUser.PhoneChangeToken = "123456" + require.NoError(ts.T(), ts.API.db.Create(emailUser), "Error saving new test user") + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, phoneUser.ID, phoneUser.PhoneChange, phoneUser.PhoneChangeToken, models.PhoneChangeToken)) + + cases := []struct { + desc string + params map[string]interface{} + // expected map[string]interface{} + user *models.User + }{ + { + desc: "Resend signup confirmation", + params: map[string]interface{}{ + "type": "signup", + "email": u.GetEmail(), + }, + user: u, + }, + { + desc: "Resend email change", + params: map[string]interface{}{ + "type": "email_change", + "email": u.GetEmail(), + }, + user: u, + }, + { + desc: "Resend email change for phone user", + params: map[string]interface{}{ + "type": "email_change", + "phone": phoneUser.GetPhone(), + }, + user: phoneUser, + }, + { + desc: "Resend phone change for email user", + params: map[string]interface{}{ + "type": "phone_change", + "email": emailUser.GetEmail(), + }, + user: emailUser, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) + req := httptest.NewRequest(http.MethodPost, "http://localhost/resend", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + switch c.params["type"] { + case mail.SignupVerification, mail.EmailChangeVerification: + dbUser, err := models.FindUserByID(ts.API.db, c.user.ID) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), dbUser) + + if c.params["type"] == mail.SignupVerification { + require.NotEqual(ts.T(), dbUser.ConfirmationToken, c.user.ConfirmationToken) + require.NotEqual(ts.T(), dbUser.ConfirmationSentAt, c.user.ConfirmationSentAt) + } else if c.params["type"] == mail.EmailChangeVerification { + require.NotEqual(ts.T(), dbUser.EmailChangeTokenNew, c.user.EmailChangeTokenNew) + require.NotEqual(ts.T(), dbUser.EmailChangeSentAt, c.user.EmailChangeSentAt) + } + } + }) + } +} diff --git a/auth_v2.169.0/internal/api/router.go b/auth_v2.169.0/internal/api/router.go new file mode 100644 index 0000000..1feb66d --- /dev/null +++ b/auth_v2.169.0/internal/api/router.go @@ -0,0 +1,92 @@ +package api + +import ( + "context" + "net/http" + + "github.com/go-chi/chi/v5" +) + +func newRouter() *router { + return &router{chi.NewRouter()} +} + +type router struct { + chi chi.Router +} + +func (r *router) Route(pattern string, fn func(*router)) { + 
r.chi.Route(pattern, func(c chi.Router) { + fn(&router{c}) + }) +} + +func (r *router) Get(pattern string, fn apiHandler) { + r.chi.Get(pattern, handler(fn)) +} +func (r *router) Post(pattern string, fn apiHandler) { + r.chi.Post(pattern, handler(fn)) +} +func (r *router) Put(pattern string, fn apiHandler) { + r.chi.Put(pattern, handler(fn)) +} +func (r *router) Delete(pattern string, fn apiHandler) { + r.chi.Delete(pattern, handler(fn)) +} + +func (r *router) With(fn middlewareHandler) *router { + c := r.chi.With(middleware(fn)) + return &router{c} +} + +func (r *router) WithBypass(fn func(next http.Handler) http.Handler) *router { + c := r.chi.With(fn) + return &router{c} +} + +func (r *router) Use(fn middlewareHandler) { + r.chi.Use(middleware(fn)) +} +func (r *router) UseBypass(fn func(next http.Handler) http.Handler) { + r.chi.Use(fn) +} + +func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + r.chi.ServeHTTP(w, req) +} + +type apiHandler func(w http.ResponseWriter, r *http.Request) error + +func handler(fn apiHandler) http.HandlerFunc { + return fn.serve +} + +func (h apiHandler) serve(w http.ResponseWriter, r *http.Request) { + if err := h(w, r); err != nil { + HandleResponseError(err, w, r) + } +} + +type middlewareHandler func(w http.ResponseWriter, r *http.Request) (context.Context, error) + +func (m middlewareHandler) handler(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m.serve(next, w, r) + }) +} + +func (m middlewareHandler) serve(next http.Handler, w http.ResponseWriter, r *http.Request) { + ctx, err := m(w, r) + if err != nil { + HandleResponseError(err, w, r) + return + } + if ctx != nil { + r = r.WithContext(ctx) + } + next.ServeHTTP(w, r) +} + +func middleware(fn middlewareHandler) func(http.Handler) http.Handler { + return fn.handler +} diff --git a/auth_v2.169.0/internal/api/saml.go b/auth_v2.169.0/internal/api/saml.go new file mode 100644 index 0000000..f32d443 --- /dev/null +++ b/auth_v2.169.0/internal/api/saml.go @@ -0,0 +1,113 @@ +package api + +import ( + "encoding/xml" + "net/http" + "net/url" + "strings" + "time" + + "github.com/crewjam/saml" + "github.com/crewjam/saml/samlsp" +) + +// getSAMLServiceProvider generates a new service provider object with the +// (optionally) provided descriptor (metadata) for the identity provider. +func (a *API) getSAMLServiceProvider(identityProvider *saml.EntityDescriptor, idpInitiated bool) *saml.ServiceProvider { + var externalURL *url.URL + + if a.config.SAML.ExternalURL != "" { + url, err := url.ParseRequestURI(a.config.SAML.ExternalURL) + if err != nil { + // this should not fail as a.config should have been validated using #Validate() + panic(err) + } + + externalURL = url + } else { + url, err := url.ParseRequestURI(a.config.API.ExternalURL) + if err != nil { + // this should not fail as a.config should have been validated using #Validate() + panic(err) + } + + externalURL = url + } + + if !strings.HasSuffix(externalURL.Path, "/") { + externalURL.Path += "/" + } + + externalURL.Path += "sso/" + + provider := samlsp.DefaultServiceProvider(samlsp.Options{ + URL: *externalURL, + Key: a.config.SAML.RSAPrivateKey, + Certificate: a.config.SAML.Certificate, + SignRequest: true, + AllowIDPInitiated: idpInitiated, + IDPMetadata: identityProvider, + }) + + provider.AuthnNameIDFormat = saml.PersistentNameIDFormat + + return &provider +} + +// SAMLMetadata serves GoTrue's SAML Service Provider metadata file. 
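
Editor's example: the router wrapper above lets every endpoint be written as func(w, r) error, with HandleResponseError applied in one place by the apiHandler adapter. A hedged sketch of what registering such a handler could look like; addHealthRoute and the /healthz path are hypothetical, while sendJSON is the helper already used by the handlers earlier in this diff:

package api

import "net/http"

// Hypothetical helper, shown only to illustrate registering an error-returning
// handler through the router wrapper above.
func addHealthRoute(r *router) {
	r.Get("/healthz", func(w http.ResponseWriter, req *http.Request) error {
		// A returned error would be rendered by HandleResponseError via the
		// apiHandler adapter; on success we just write the JSON body.
		return sendJSON(w, http.StatusOK, map[string]string{"status": "ok"})
	})
}
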
+func (a *API) SAMLMetadata(w http.ResponseWriter, r *http.Request) error { + serviceProvider := a.getSAMLServiceProvider(nil, true) + + metadata := serviceProvider.Metadata() + + if r.FormValue("download") == "true" { + // 5 year expiration, comparable to what GSuite does + metadata.ValidUntil = time.Now().UTC().AddDate(5, 0, 0) + } + + for i := range metadata.SPSSODescriptors { + // we set this to false since the IdP initiated flow can only + // sign the Assertion, and not the full Request + // unfortunately this is hardcoded in the crewjam library if + // signatures (instead of encryption) are supported + // https://github.com/crewjam/saml/blob/v0.4.8/service_provider.go#L217 + metadata.SPSSODescriptors[i].AuthnRequestsSigned = nil + + // advertize the requested NameID formats (either persistent or email address) + metadata.SPSSODescriptors[i].NameIDFormats = []saml.NameIDFormat{ + saml.EmailAddressNameIDFormat, + saml.PersistentNameIDFormat, + } + } + + for i := range metadata.SPSSODescriptors { + spd := &metadata.SPSSODescriptors[i] + + var keyDescriptors []saml.KeyDescriptor + + for _, kd := range spd.KeyDescriptors { + // only advertize key as usable for encryption if allowed + if kd.Use == "signing" || (a.config.SAML.AllowEncryptedAssertions && kd.Use == "encryption") { + keyDescriptors = append(keyDescriptors, kd) + } + } + + spd.KeyDescriptors = keyDescriptors + } + + metadataXML, err := xml.Marshal(metadata) + if err != nil { + return err + } + + w.Header().Set("Content-Type", "application/xml") + w.Header().Set("Cache-Control", "public, max-age=600") // cache at CDN for 10 minutes + + if r.FormValue("download") == "true" { + w.Header().Set("Content-Disposition", "attachment; filename=\"metadata.xml\"") + } + + _, err = w.Write(metadataXML) + + return err +} diff --git a/auth_v2.169.0/internal/api/saml_test.go b/auth_v2.169.0/internal/api/saml_test.go new file mode 100644 index 0000000..a290fb2 --- /dev/null +++ b/auth_v2.169.0/internal/api/saml_test.go @@ -0,0 +1,59 @@ +package api + +import ( + tst "testing" + "time" + + "encoding/xml" + "net/http" + "net/http/httptest" + + "github.com/crewjam/saml" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +func TestSAMLMetadataWithAPI(t *tst.T) { + config, err := conf.LoadGlobal(apiTestConfig) + require.NoError(t, err) + config.API.ExternalURL = "https://projectref.supabase.co/auth/v1/" + config.SAML.Enabled = true + config.SAML.PrivateKey = 
"MIIEowIBAAKCAQEAszrVveMQcSsa0Y+zN1ZFb19cRS0jn4UgIHTprW2tVBmO2PABzjY3XFCfx6vPirMAPWBYpsKmXrvm1tr0A6DZYmA8YmJd937VUQ67fa6DMyppBYTjNgGEkEhmKuszvF3MARsIKCGtZqUrmS7UG4404wYxVppnr2EYm3RGtHlkYsXu20MBqSDXP47bQP+PkJqC3BuNGk3xt5UHl2FSFpTHelkI6lBynw16B+lUT1F96SERNDaMqi/TRsZdGe5mB/29ngC/QBMpEbRBLNRir5iUevKS7Pn4aph9Qjaxx/97siktK210FJT23KjHpgcUfjoQ6BgPBTLtEeQdRyDuc/CgfwIDAQABAoIBAGYDWOEpupQPSsZ4mjMnAYJwrp4ZISuMpEqVAORbhspVeb70bLKonT4IDcmiexCg7cQBcLQKGpPVM4CbQ0RFazXZPMVq470ZDeWDEyhoCfk3bGtdxc1Zc9CDxNMs6FeQs6r1beEZug6weG5J/yRn/qYxQife3qEuDMl+lzfl2EN3HYVOSnBmdt50dxRuX26iW3nqqbMRqYn9OHuJ1LvRRfYeyVKqgC5vgt/6Tf7DAJwGe0dD7q08byHV8DBZ0pnMVU0bYpf1GTgMibgjnLjK//EVWafFHtN+RXcjzGmyJrk3+7ZyPUpzpDjO21kpzUQLrpEkkBRnmg6bwHnSrBr8avECgYEA3pq1PTCAOuLQoIm1CWR9/dhkbJQiKTJevlWV8slXQLR50P0WvI2RdFuSxlWmA4xZej8s4e7iD3MYye6SBsQHygOVGc4efvvEZV8/XTlDdyj7iLVGhnEmu2r7AFKzy8cOvXx0QcLg+zNd7vxZv/8D3Qj9Jje2LjLHKM5n/dZ3RzUCgYEAzh5Lo2anc4WN8faLGt7rPkGQF+7/18ImQE11joHWa3LzAEy7FbeOGpE/vhOv5umq5M/KlWFIRahMEQv4RusieHWI19ZLIP+JwQFxWxS+cPp3xOiGcquSAZnlyVSxZ//dlVgaZq2o2MfrxECcovRlaknl2csyf+HjFFwKlNxHm2MCgYAr//R3BdEy0oZeVRndo2lr9YvUEmu2LOihQpWDCd0fQw0ZDA2kc28eysL2RROte95r1XTvq6IvX5a0w11FzRWlDpQ4J4/LlcQ6LVt+98SoFwew+/PWuyLmxLycUbyMOOpm9eSc4wJJZNvaUzMCSkvfMtmm5jgyZYMMQ9A2Ul/9SQKBgB9mfh9mhBwVPIqgBJETZMMXOdxrjI5SBYHGSyJqpT+5Q0vIZLfqPrvNZOiQFzwWXPJ+tV4Mc/YorW3rZOdo6tdvEGnRO6DLTTEaByrY/io3/gcBZXoSqSuVRmxleqFdWWRnB56c1hwwWLqNHU+1671FhL6pNghFYVK4suP6qu4BAoGBAMk+VipXcIlD67mfGrET/xDqiWWBZtgTzTMjTpODhDY1GZck1eb4CQMP5j5V3gFJ4cSgWDJvnWg8rcz0unz/q4aeMGl1rah5WNDWj1QKWMS6vJhMHM/rqN1WHWR0ZnV83svYgtg0zDnQKlLujqW4JmGXLMU7ur6a+e6lpa1fvLsP" + config.API.MaxRequestDuration = 5 * time.Second + + require.NoError(t, config.ApplyDefaults()) + require.NoError(t, config.SAML.PopulateFields(config.API.ExternalURL)) + + require.NotNil(t, config.SAML.Certificate) + + api := NewAPI(config, nil) + + // Setup request + req := httptest.NewRequest(http.MethodGet, "http://localhost/sso/saml/metadata", nil) + + w := httptest.NewRecorder() + api.handler.ServeHTTP(w, req) + require.Equal(t, w.Code, http.StatusOK) + + metadata := saml.EntityDescriptor{} + require.NoError(t, xml.Unmarshal(w.Body.Bytes(), &metadata)) + + require.Equal(t, metadata.EntityID, "https://projectref.supabase.co/auth/v1/sso/saml/metadata") + require.Equal(t, len(metadata.SPSSODescriptors), 1) + + require.Nil(t, metadata.SPSSODescriptors[0].AuthnRequestsSigned) + require.True(t, *(metadata.SPSSODescriptors[0].WantAssertionsSigned)) + + require.Equal(t, len(metadata.SPSSODescriptors[0].AssertionConsumerServices), 2) + require.Equal(t, metadata.SPSSODescriptors[0].AssertionConsumerServices[0].Location, "https://projectref.supabase.co/auth/v1/sso/saml/acs") + require.Equal(t, metadata.SPSSODescriptors[0].AssertionConsumerServices[1].Location, "https://projectref.supabase.co/auth/v1/sso/saml/acs") + require.Equal(t, len(metadata.SPSSODescriptors[0].SingleLogoutServices), 1) + require.Equal(t, metadata.SPSSODescriptors[0].SingleLogoutServices[0].Location, "https://projectref.supabase.co/auth/v1/sso/saml/slo") + + require.Equal(t, len(metadata.SPSSODescriptors[0].KeyDescriptors), 1) + require.Equal(t, metadata.SPSSODescriptors[0].KeyDescriptors[0].Use, "signing") + + require.Equal(t, len(metadata.SPSSODescriptors[0].NameIDFormats), 2) + require.Equal(t, metadata.SPSSODescriptors[0].NameIDFormats[0], saml.EmailAddressNameIDFormat) + require.Equal(t, metadata.SPSSODescriptors[0].NameIDFormats[1], saml.PersistentNameIDFormat) +} diff --git a/auth_v2.169.0/internal/api/samlacs.go b/auth_v2.169.0/internal/api/samlacs.go new 
file mode 100644 index 0000000..8627f93 --- /dev/null +++ b/auth_v2.169.0/internal/api/samlacs.go @@ -0,0 +1,327 @@ +package api + +import ( + "context" + "encoding/base64" + "encoding/json" + "encoding/xml" + "net/http" + "net/url" + "time" + + "github.com/crewjam/saml" + "github.com/fatih/structs" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +func (a *API) samlDestroyRelayState(ctx context.Context, relayState *models.SAMLRelayState) error { + db := a.db.WithContext(ctx) + + // It's OK to destroy the RelayState, as a user will + // likely initiate a completely new login flow, instead + // of reusing the same one. + + return db.Transaction(func(tx *storage.Connection) error { + return tx.Destroy(relayState) + }) +} + +func IsSAMLMetadataStale(idpMetadata *saml.EntityDescriptor, samlProvider models.SAMLProvider) bool { + now := time.Now() + + hasValidityExpired := !idpMetadata.ValidUntil.IsZero() && now.After(idpMetadata.ValidUntil) + hasCacheDurationExceeded := idpMetadata.CacheDuration != 0 && now.After(samlProvider.UpdatedAt.Add(idpMetadata.CacheDuration)) + + // if metadata XML does not publish validity or caching information, update once in 24 hours + needsForceUpdate := idpMetadata.ValidUntil.IsZero() && idpMetadata.CacheDuration == 0 && now.After(samlProvider.UpdatedAt.Add(24*time.Hour)) + + return hasValidityExpired || hasCacheDurationExceeded || needsForceUpdate +} + +func (a *API) SamlAcs(w http.ResponseWriter, r *http.Request) error { + if err := a.handleSamlAcs(w, r); err != nil { + u, uerr := url.Parse(a.config.SiteURL) + if uerr != nil { + return internalServerError("site url is improperly formattted").WithInternalError(err) + } + + q := getErrorQueryString(err, utilities.GetRequestID(r.Context()), observability.GetLogEntry(r).Entry, u.Query()) + u.RawQuery = q.Encode() + http.Redirect(w, r, u.String(), http.StatusSeeOther) + } + return nil +} + +// handleSamlAcs implements the main Assertion Consumer Service endpoint behavior. +func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + + db := a.db.WithContext(ctx) + config := a.config + log := observability.GetLogEntry(r).Entry + + relayStateValue := r.FormValue("RelayState") + relayStateUUID := uuid.FromStringOrNil(relayStateValue) + relayStateURL, _ := url.ParseRequestURI(relayStateValue) + + entityId := "" + initiatedBy := "" + redirectTo := "" + var requestIds []string + + var flowState *models.FlowState + if relayStateUUID != uuid.Nil { + // relay state is a valid UUID, therefore this is likely a SP initiated flow + + relayState, err := models.FindSAMLRelayStateByID(db, relayStateUUID) + if models.IsNotFoundError(err) { + return notFoundError(ErrorCodeSAMLRelayStateNotFound, "SAML RelayState does not exist, try logging in again?") + } else if err != nil { + return err + } + + if time.Since(relayState.CreatedAt) >= a.config.SAML.RelayStateValidityPeriod { + if err := a.samlDestroyRelayState(ctx, relayState); err != nil { + return internalServerError("SAML RelayState has expired and destroying it failed. Try logging in again?").WithInternalError(err) + } + + return unprocessableEntityError(ErrorCodeSAMLRelayStateExpired, "SAML RelayState has expired. 
Try logging in again?") + } + + // TODO: add abuse detection to bind the RelayState UUID with a + // HTTP-Only cookie + + ssoProvider, err := models.FindSSOProviderByID(db, relayState.SSOProviderID) + if err != nil { + return internalServerError("Unable to find SSO Provider from SAML RelayState") + } + + initiatedBy = "sp" + entityId = ssoProvider.SAMLProvider.EntityID + redirectTo = relayState.RedirectTo + requestIds = append(requestIds, relayState.RequestID) + if relayState.FlowState != nil { + flowState = relayState.FlowState + } + + if err := a.samlDestroyRelayState(ctx, relayState); err != nil { + return err + } + } else if relayStateValue == "" || relayStateURL != nil { + // RelayState may be a URL in which case it's the URL where the + // IdP is telling us to redirect the user to + + if r.FormValue("SAMLart") != "" { + // SAML Artifact responses are possible only when + // RelayState can be used to identify the Identity + // Provider. + return badRequestError(ErrorCodeValidationFailed, "SAML Artifact response can only be used with SP initiated flow") + } + + samlResponse := r.FormValue("SAMLResponse") + if samlResponse == "" { + return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is missing") + } + + responseXML, err := base64.StdEncoding.DecodeString(samlResponse) + if err != nil { + return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is not a valid Base64 string") + } + + var peekResponse saml.Response + err = xml.Unmarshal(responseXML, &peekResponse) + if err != nil { + return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is not a valid XML SAML assertion").WithInternalError(err) + } + + initiatedBy = "idp" + entityId = peekResponse.Issuer.Value + redirectTo = relayStateValue + } else { + // RelayState can't be identified, so SAML flow can't continue + return badRequestError(ErrorCodeValidationFailed, "SAML RelayState is not a valid UUID or URL") + } + + ssoProvider, err := models.FindSAMLProviderByEntityID(db, entityId) + if models.IsNotFoundError(err) { + return notFoundError(ErrorCodeSAMLIdPNotFound, "A SAML connection has not been established with this Identity Provider") + } else if err != nil { + return err + } + + idpMetadata, err := ssoProvider.SAMLProvider.EntityDescriptor() + if err != nil { + return err + } + + samlMetadataModified := false + + if ssoProvider.SAMLProvider.MetadataURL == nil { + if !idpMetadata.ValidUntil.IsZero() && time.Until(idpMetadata.ValidUntil) <= (30*24*60)*time.Second { + logentry := log.WithField("sso_provider_id", ssoProvider.ID.String()) + logentry = logentry.WithField("expires_in", time.Until(idpMetadata.ValidUntil).String()) + logentry = logentry.WithField("valid_until", idpMetadata.ValidUntil) + logentry = logentry.WithField("saml_entity_id", ssoProvider.SAMLProvider.EntityID) + + logentry.Warn("SAML Metadata for identity provider will expire soon! 
Update its metadata_xml!") + } + } else if *ssoProvider.SAMLProvider.MetadataURL != "" && IsSAMLMetadataStale(idpMetadata, ssoProvider.SAMLProvider) { + rawMetadata, err := fetchSAMLMetadata(ctx, *ssoProvider.SAMLProvider.MetadataURL) + if err != nil { + // Fail silently but raise warning and continue with existing metadata + logentry := log.WithField("sso_provider_id", ssoProvider.ID.String()) + logentry = logentry.WithField("expires_in", time.Until(idpMetadata.ValidUntil).String()) + logentry = logentry.WithField("valid_until", idpMetadata.ValidUntil) + logentry = logentry.WithError(err) + logentry.Warn("SAML Metadata could not be retrieved, continuing with existing metadata") + } else { + ssoProvider.SAMLProvider.MetadataXML = string(rawMetadata) + samlMetadataModified = true + } + } + + serviceProvider := a.getSAMLServiceProvider(idpMetadata, initiatedBy == "idp") + spAssertion, err := serviceProvider.ParseResponse(r, requestIds) + if err != nil { + if ire, ok := err.(*saml.InvalidResponseError); ok { + return badRequestError(ErrorCodeValidationFailed, "SAML Assertion is not valid %s", ire.Response).WithInternalError(ire.PrivateErr) + } + + return badRequestError(ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(err) + } + + assertion := SAMLAssertion{ + spAssertion, + } + + userID := assertion.UserID() + if userID == "" { + return badRequestError(ErrorCodeSAMLAssertionNoUserID, "SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") + } + + claims := assertion.Process(ssoProvider.SAMLProvider.AttributeMapping) + + email, ok := claims["email"].(string) + if !ok || email == "" { + // mapping does not identify the email attribute, try to figure it out + email = assertion.Email() + } + + if email == "" { + return badRequestError(ErrorCodeSAMLAssertionNoEmail, "SAML Assertion does not contain an email address") + } else { + claims["email"] = email + } + + jsonClaims, err := json.Marshal(claims) + if err != nil { + return internalServerError("Mapped claims from provider could not be serialized into JSON").WithInternalError(err) + } + + providerClaims := &provider.Claims{} + if err := json.Unmarshal(jsonClaims, providerClaims); err != nil { + return internalServerError("Mapped claims from provider could not be deserialized from JSON").WithInternalError(err) + } + + providerClaims.Subject = userID + providerClaims.Issuer = ssoProvider.SAMLProvider.EntityID + providerClaims.Email = email + providerClaims.EmailVerified = true + + providerClaimsMap := structs.Map(providerClaims) + + // remove all of the parsed claims, so that the rest can go into CustomClaims + for key := range providerClaimsMap { + delete(claims, key) + } + + providerClaims.CustomClaims = claims + + var userProvidedData provider.UserProvidedData + + userProvidedData.Emails = append(userProvidedData.Emails, provider.Email{ + Email: email, + Verified: true, + Primary: true, + }) + + // userProvidedData.Provider.Type = "saml" + // userProvidedData.Provider.ID = ssoProvider.ID.String() + // userProvidedData.Provider.SAMLEntityID = ssoProvider.SAMLProvider.EntityID + // userProvidedData.Provider.SAMLInitiatedBy = initiatedBy + + userProvidedData.Metadata = providerClaims + + // TODO: below + // refreshTokenParams.SSOProviderID = ssoProvider.ID + // refreshTokenParams.InitiatedByProvider = initiatedBy == "idp" + // refreshTokenParams.NotBefore = assertion.NotBefore() + // refreshTokenParams.NotAfter = assertion.NotAfter() + + notAfter := 
assertion.NotAfter() + + var grantParams models.GrantParams + + grantParams.FillGrantParams(r) + + if !notAfter.IsZero() { + grantParams.SessionNotAfter = ¬After + } + + var token *AccessTokenResponse + if samlMetadataModified { + if err := db.UpdateColumns(&ssoProvider.SAMLProvider, "metadata_xml", "updated_at"); err != nil { + return err + } + } + + if err := db.Transaction(func(tx *storage.Connection) error { + var terr error + var user *models.User + + // accounts potentially created via SAML can contain non-unique email addresses in the auth.users table + if user, terr = a.createAccountFromExternalIdentity(tx, r, &userProvidedData, "sso:"+ssoProvider.ID.String()); terr != nil { + return terr + } + if flowState != nil { + // This means that the callback is using PKCE + flowState.UserID = &(user.ID) + if terr := tx.Update(flowState); terr != nil { + return terr + } + } + + token, terr = a.issueRefreshToken(r, tx, user, models.SSOSAML, grantParams) + + if terr != nil { + return internalServerError("Unable to issue refresh token from SAML Assertion").WithInternalError(terr) + } + + return nil + }); err != nil { + return err + } + + if !utilities.IsRedirectURLValid(config, redirectTo) { + redirectTo = config.SiteURL + } + if flowState != nil { + // This means that the callback is using PKCE + // Set the flowState.AuthCode to the query param here + redirectTo, err = a.prepPKCERedirectURL(redirectTo, flowState.AuthCode) + if err != nil { + return err + } + http.Redirect(w, r, redirectTo, http.StatusFound) + return nil + + } + http.Redirect(w, r, token.AsRedirectURL(redirectTo, url.Values{}), http.StatusFound) + + return nil +} diff --git a/auth_v2.169.0/internal/api/samlassertion.go b/auth_v2.169.0/internal/api/samlassertion.go new file mode 100644 index 0000000..fdf9323 --- /dev/null +++ b/auth_v2.169.0/internal/api/samlassertion.go @@ -0,0 +1,188 @@ +package api + +import ( + "strings" + "time" + + "github.com/crewjam/saml" + "github.com/supabase/auth/internal/models" +) + +type SAMLAssertion struct { + *saml.Assertion +} + +const ( + SAMLSubjectIDAttributeName = "urn:oasis:names:tc:SAML:attribute:subject-id" +) + +// Attribute returns the first matching attribute value in the attribute +// statements where name equals the official SAML attribute Name or +// FriendlyName. Returns nil if such an attribute can't be found. +func (a *SAMLAssertion) Attribute(name string) []saml.AttributeValue { + var values []saml.AttributeValue + + for _, stmt := range a.AttributeStatements { + for _, attr := range stmt.Attributes { + if strings.EqualFold(attr.Name, name) || strings.EqualFold(attr.FriendlyName, name) { + values = append(values, attr.Values...) + } + } + } + + return values +} + +// UserID returns the best choice for a persistent user identifier on the +// Identity Provider side. Don't assume the format of the string returned, as +// it's Identity Provider specific. +func (a *SAMLAssertion) UserID() string { + // First we look up the SAMLSubjectIDAttributeName in the attribute + // section of the assertion, as this is the preferred way to + // persistently identify users in SAML 2.0. + // See: https://docs.oasis-open.org/security/saml-subject-id-attr/v1.0/cs01/saml-subject-id-attr-v1.0-cs01.html#_Toc536097226 + values := a.Attribute(SAMLSubjectIDAttributeName) + if len(values) > 0 { + return values[0].Value + } + + // Otherwise, fall back to the SubjectID value. 
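+	// The resulting precedence is: subject-id attribute first, then a persistent
+	// or email-address NameID; transient NameIDs yield "" here because they
+	// cannot identify the same user across assertions.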
+ subjectID, isPersistent := a.SubjectID() + if !isPersistent { + return "" + } + + return subjectID +} + +// SubjectID returns the user identifier in present in the Subject section of +// the SAML assertion. Note that this way of identifying the Subject is +// generally superseded by the SAMLSubjectIDAttributeName assertion attribute; +// tho must be present in all assertions. It can have a few formats, of which +// the most important are: saml.EmailAddressNameIDFormat (meaning the user ID +// is an email address), saml.PersistentNameIDFormat (the user ID is an opaque +// string that does not change with each assertion, e.g. UUID), +// saml.TransientNameIDFormat (the user ID changes with each assertion -- can't +// be used to identify a user). The boolean returned identifies if the user ID +// is persistent. If it's an email address, it's lowercased just in case. +func (a *SAMLAssertion) SubjectID() (string, bool) { + if a.Subject == nil { + return "", false + } + + if a.Subject.NameID == nil { + return "", false + } + + if a.Subject.NameID.Value == "" { + return "", false + } + + if a.Subject.NameID.Format == string(saml.EmailAddressNameIDFormat) { + return strings.ToLower(strings.TrimSpace(a.Subject.NameID.Value)), true + } + + // all other NameID formats are regarded as persistent + isPersistent := a.Subject.NameID.Format != string(saml.TransientNameIDFormat) + + return a.Subject.NameID.Value, isPersistent +} + +// Email returns the best guess for an email address. +func (a *SAMLAssertion) Email() string { + attributeNames := []string{ + "urn:oid:0.9.2342.19200300.100.1.3", + "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress", + "http://schemas.xmlsoap.org/claims/EmailAddress", + "mail", + "Mail", + "email", + } + + for _, name := range attributeNames { + for _, attr := range a.Attribute(name) { + if attr.Value != "" { + return attr.Value + } + } + } + + if a.Subject.NameID.Format == string(saml.EmailAddressNameIDFormat) { + return a.Subject.NameID.Value + } + + return "" +} + +// Process processes this assertion according to the SAMLAttributeMapping. Never returns nil. +func (a *SAMLAssertion) Process(mapping models.SAMLAttributeMapping) map[string]interface{} { + ret := make(map[string]interface{}) + + for key, mapper := range mapping.Keys { + names := []string{} + if mapper.Name != "" { + names = append(names, mapper.Name) + } + names = append(names, mapper.Names...) + + setKey := false + + for _, name := range names { + for _, attr := range a.Attribute(name) { + if attr.Value != "" { + setKey = true + + if mapper.Array { + if ret[key] == nil { + ret[key] = []string{} + } + + ret[key] = append(ret[key].([]string), attr.Value) + } else { + ret[key] = attr.Value + break + } + } + } + + if setKey { + break + } + } + + if !setKey && mapper.Default != nil { + ret[key] = mapper.Default + } + } + + return ret +} + +// NotBefore extracts the time before which this assertion should not be +// considered. +func (a *SAMLAssertion) NotBefore() time.Time { + if a.Conditions != nil && !a.Conditions.NotBefore.IsZero() { + return a.Conditions.NotBefore.UTC() + } + + return time.Time{} +} + +// NotAfter extracts the time at which or after this assertion should not be +// considered. 
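+// It returns the first non-zero SessionNotOnOrAfter found among the assertion's
+// AuthnStatements, or the zero time if none is present; the ACS handler uses a
+// non-zero value as GrantParams.SessionNotAfter so the issued session can be
+// bounded by the IdP session.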
+func (a *SAMLAssertion) NotAfter() time.Time { + var notOnOrAfter time.Time + + for _, statement := range a.AuthnStatements { + if statement.SessionNotOnOrAfter == nil { + continue + } + + notOnOrAfter = *statement.SessionNotOnOrAfter + if !notOnOrAfter.IsZero() { + break + } + } + + return notOnOrAfter +} diff --git a/auth_v2.169.0/internal/api/samlassertion_test.go b/auth_v2.169.0/internal/api/samlassertion_test.go new file mode 100644 index 0000000..b7461b2 --- /dev/null +++ b/auth_v2.169.0/internal/api/samlassertion_test.go @@ -0,0 +1,347 @@ +package api + +import ( + tst "testing" + + "encoding/xml" + + "github.com/crewjam/saml" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/models" +) + +func TestSAMLAssertionUserID(t *tst.T) { + type spec struct { + xml string + userID string + } + + examples := []spec{ + { + xml: ` + + https://example.com/saml + + + transient-name-id + + + + + + + http://localhost:9999/saml/metadata + + + + + + urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport + + + + + +`, + userID: "", + }, + { + xml: ` + + https://example.com/saml + + + persistent-name-id + + + + + + + http://localhost:9999/saml/metadata + + + + + + urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport + + + + + +`, + userID: "persistent-name-id", + }, + { + xml: ` + + https://example.com/saml + + + name-id@example.com + + + + + + + http://localhost:9999/saml/metadata + + + + + + urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport + + + + + +`, + userID: "name-id@example.com", + }, + { + xml: ` + + https://example.com/saml + + + name-id@example.com + + + + + + + http://localhost:9999/saml/metadata + + + + + + urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport + + + + + subject-id + + + +`, + userID: "subject-id", + }, + } + + for i, example := range examples { + rawAssertion := saml.Assertion{} + require.NoError(t, xml.Unmarshal([]byte(example.xml), &rawAssertion)) + + assertion := SAMLAssertion{ + &rawAssertion, + } + + userID := assertion.UserID() + + require.Equal(t, userID, example.userID, "example %d had different user ID", i) + } +} + +func TestSAMLAssertionProcessing(t *tst.T) { + type spec struct { + desc string + xml string + mapping models.SAMLAttributeMapping + expected map[string]interface{} + } + + examples := []spec{ + { + desc: "valid attribute and mapping", + xml: ` + + + + someone@example.com + + + + `, + mapping: models.SAMLAttributeMapping{ + Keys: map[string]models.SAMLAttribute{ + "email": { + Name: "mail", + }, + }, + }, + expected: map[string]interface{}{ + "email": "someone@example.com", + }, + }, + { + desc: "valid attributes, use first attribute found in Names", + xml: ` + + + + old-soap@example.com + + + soap@example.com + + + + `, + mapping: models.SAMLAttributeMapping{ + Keys: map[string]models.SAMLAttribute{ + "email": { + Names: []string{ + "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress", + "http://schemas.xmlsoap.org/claims/EmailAddress", + }, + }, + }, + }, + expected: map[string]interface{}{ + "email": "soap@example.com", + }, + }, + { + desc: "valid groups attribute", + xml: ` + + + + group1 + group2 + + + soap@example.com + + + + `, + mapping: models.SAMLAttributeMapping{ + Keys: map[string]models.SAMLAttribute{ + "email": { + Names: []string{ + "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/emailaddress", + "http://schemas.xmlsoap.org/claims/EmailAddress", + }, + }, + "groups": { + Name: "groups", + Array: true, + }, + }, + }, + expected: 
map[string]interface{}{ + "email": "soap@example.com", + "groups": []string{ + "group1", + "group2", + }, + }, + }, + { + desc: "missing attribute use default value", + xml: ` + + + + someone@example.com + + + +`, + mapping: models.SAMLAttributeMapping{ + Keys: map[string]models.SAMLAttribute{ + "email": { + Name: "email", + }, + "role": { + Default: "member", + }, + }, + }, + expected: map[string]interface{}{ + "email": "someone@example.com", + "role": "member", + }, + }, + { + desc: "use default value even if attribute exists but is not specified in mapping", + xml: ` + + + + someone@example.com + + + admin + + + + `, + mapping: models.SAMLAttributeMapping{ + Keys: map[string]models.SAMLAttribute{ + "email": { + Name: "mail", + }, + "role": { + Default: "member", + }, + }, + }, + expected: map[string]interface{}{ + "email": "someone@example.com", + "role": "member", + }, + }, + { + desc: "use value in XML when attribute exists and is specified in mapping", + xml: ` + + + + someone@example.com + + + admin + + + + `, + mapping: models.SAMLAttributeMapping{ + Keys: map[string]models.SAMLAttribute{ + "email": { + Name: "mail", + }, + "role": { + Name: "role", + Default: "member", + }, + }, + }, + expected: map[string]interface{}{ + "email": "someone@example.com", + "role": "admin", + }, + }, + } + + for i, example := range examples { + t.Run(example.desc, func(t *tst.T) { + rawAssertion := saml.Assertion{} + require.NoError(t, xml.Unmarshal([]byte(example.xml), &rawAssertion)) + + assertion := SAMLAssertion{ + &rawAssertion, + } + + result := assertion.Process(example.mapping) + require.Equal(t, example.expected, result, "example %d had different processing", i) + }) + } +} diff --git a/auth_v2.169.0/internal/api/settings.go b/auth_v2.169.0/internal/api/settings.go new file mode 100644 index 0000000..bc2f386 --- /dev/null +++ b/auth_v2.169.0/internal/api/settings.go @@ -0,0 +1,79 @@ +package api + +import "net/http" + +type ProviderSettings struct { + AnonymousUsers bool `json:"anonymous_users"` + Apple bool `json:"apple"` + Azure bool `json:"azure"` + Bitbucket bool `json:"bitbucket"` + Discord bool `json:"discord"` + Facebook bool `json:"facebook"` + Figma bool `json:"figma"` + Fly bool `json:"fly"` + GitHub bool `json:"github"` + GitLab bool `json:"gitlab"` + Google bool `json:"google"` + Keycloak bool `json:"keycloak"` + Kakao bool `json:"kakao"` + Linkedin bool `json:"linkedin"` + LinkedinOIDC bool `json:"linkedin_oidc"` + Notion bool `json:"notion"` + Spotify bool `json:"spotify"` + Slack bool `json:"slack"` + SlackOIDC bool `json:"slack_oidc"` + WorkOS bool `json:"workos"` + Twitch bool `json:"twitch"` + Twitter bool `json:"twitter"` + Email bool `json:"email"` + Phone bool `json:"phone"` + Zoom bool `json:"zoom"` +} + +type Settings struct { + ExternalProviders ProviderSettings `json:"external"` + DisableSignup bool `json:"disable_signup"` + MailerAutoconfirm bool `json:"mailer_autoconfirm"` + PhoneAutoconfirm bool `json:"phone_autoconfirm"` + SmsProvider string `json:"sms_provider"` + SAMLEnabled bool `json:"saml_enabled"` +} + +func (a *API) Settings(w http.ResponseWriter, r *http.Request) error { + config := a.config + + return sendJSON(w, http.StatusOK, &Settings{ + ExternalProviders: ProviderSettings{ + AnonymousUsers: config.External.AnonymousUsers.Enabled, + Apple: config.External.Apple.Enabled, + Azure: config.External.Azure.Enabled, + Bitbucket: config.External.Bitbucket.Enabled, + Discord: config.External.Discord.Enabled, + Facebook: config.External.Facebook.Enabled, + 
Figma: config.External.Figma.Enabled, + Fly: config.External.Fly.Enabled, + GitHub: config.External.Github.Enabled, + GitLab: config.External.Gitlab.Enabled, + Google: config.External.Google.Enabled, + Kakao: config.External.Kakao.Enabled, + Keycloak: config.External.Keycloak.Enabled, + Linkedin: config.External.Linkedin.Enabled, + LinkedinOIDC: config.External.LinkedinOIDC.Enabled, + Notion: config.External.Notion.Enabled, + Spotify: config.External.Spotify.Enabled, + Slack: config.External.Slack.Enabled, + SlackOIDC: config.External.SlackOIDC.Enabled, + Twitch: config.External.Twitch.Enabled, + Twitter: config.External.Twitter.Enabled, + WorkOS: config.External.WorkOS.Enabled, + Email: config.External.Email.Enabled, + Phone: config.External.Phone.Enabled, + Zoom: config.External.Zoom.Enabled, + }, + DisableSignup: config.DisableSignup, + MailerAutoconfirm: config.Mailer.Autoconfirm, + PhoneAutoconfirm: config.Sms.Autoconfirm, + SmsProvider: config.Sms.Provider, + SAMLEnabled: config.SAML.Enabled, + }) +} diff --git a/auth_v2.169.0/internal/api/settings_test.go b/auth_v2.169.0/internal/api/settings_test.go new file mode 100644 index 0000000..767bcf7 --- /dev/null +++ b/auth_v2.169.0/internal/api/settings_test.go @@ -0,0 +1,73 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSettings_DefaultProviders(t *testing.T) { + api, _, err := setupAPIForTest() + require.NoError(t, err) + + // Setup request + req := httptest.NewRequest(http.MethodGet, "http://localhost/settings", nil) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + api.handler.ServeHTTP(w, req) + require.Equal(t, w.Code, http.StatusOK) + resp := Settings{} + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + + p := resp.ExternalProviders + + require.False(t, p.Phone) + require.True(t, p.Email) + require.True(t, p.Azure) + require.True(t, p.Bitbucket) + require.True(t, p.Discord) + require.True(t, p.Facebook) + require.True(t, p.Notion) + require.True(t, p.Spotify) + require.True(t, p.Slack) + require.True(t, p.SlackOIDC) + require.True(t, p.Google) + require.True(t, p.Kakao) + require.True(t, p.Keycloak) + require.True(t, p.Linkedin) + require.True(t, p.LinkedinOIDC) + require.True(t, p.GitHub) + require.True(t, p.GitLab) + require.True(t, p.Twitch) + require.True(t, p.WorkOS) + require.True(t, p.Zoom) + +} + +func TestSettings_EmailDisabled(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + config.External.Email.Enabled = false + + // Setup request + req := httptest.NewRequest(http.MethodGet, "http://localhost/settings", nil) + req.Header.Set("Content-Type", "application/json") + + ctx := context.Background() + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + api.handler.ServeHTTP(w, req) + require.Equal(t, w.Code, http.StatusOK) + resp := Settings{} + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + + p := resp.ExternalProviders + require.False(t, p.Email) +} diff --git a/auth_v2.169.0/internal/api/signup.go b/auth_v2.169.0/internal/api/signup.go new file mode 100644 index 0000000..1c74da6 --- /dev/null +++ b/auth_v2.169.0/internal/api/signup.go @@ -0,0 +1,390 @@ +package api + +import ( + "context" + "net/http" + "time" + + "github.com/fatih/structs" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/api/sms_provider" + 
"github.com/supabase/auth/internal/metering" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// SignupParams are the parameters the Signup endpoint accepts +type SignupParams struct { + Email string `json:"email"` + Phone string `json:"phone"` + Password string `json:"password"` + Data map[string]interface{} `json:"data"` + Provider string `json:"-"` + Aud string `json:"-"` + Channel string `json:"channel"` + CodeChallengeMethod string `json:"code_challenge_method"` + CodeChallenge string `json:"code_challenge"` +} + +func (a *API) validateSignupParams(ctx context.Context, p *SignupParams) error { + config := a.config + + if p.Password == "" { + return badRequestError(ErrorCodeValidationFailed, "Signup requires a valid password") + } + + if err := a.checkPasswordStrength(ctx, p.Password); err != nil { + return err + } + if p.Email != "" && p.Phone != "" { + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on signup.") + } + if p.Provider == "phone" && !sms_provider.IsValidMessageChannel(p.Channel, config) { + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) + } + // PKCE not needed as phone signups already return access token in body + if p.Phone != "" && p.CodeChallenge != "" { + return badRequestError(ErrorCodeValidationFailed, "PKCE not supported for phone signups") + } + if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { + return err + } + + return nil +} + +func (p *SignupParams) ConfigureDefaults() { + if p.Email != "" { + p.Provider = "email" + } else if p.Phone != "" { + p.Provider = "phone" + } + if p.Data == nil { + p.Data = make(map[string]interface{}) + } + + // For backwards compatibility, we default to SMS if params Channel is not specified + if p.Phone != "" && p.Channel == "" { + p.Channel = sms_provider.SMSProvider + } +} + +func (params *SignupParams) ToUserModel(isSSOUser bool) (user *models.User, err error) { + switch params.Provider { + case "email": + user, err = models.NewUser("", params.Email, params.Password, params.Aud, params.Data) + case "phone": + user, err = models.NewUser(params.Phone, "", params.Password, params.Aud, params.Data) + case "anonymous": + user, err = models.NewUser("", "", "", params.Aud, params.Data) + user.IsAnonymous = true + default: + // handles external provider case + user, err = models.NewUser("", params.Email, params.Password, params.Aud, params.Data) + } + if err != nil { + err = internalServerError("Database error creating user").WithInternalError(err) + return + } + user.IsSSOUser = isSSOUser + if user.AppMetaData == nil { + user.AppMetaData = make(map[string]interface{}) + } + + user.Identities = make([]models.Identity, 0) + + if params.Provider != "anonymous" { + // TODO: Deprecate "provider" field + user.AppMetaData["provider"] = params.Provider + + user.AppMetaData["providers"] = []string{params.Provider} + } + + return user, nil +} + +// Signup is the endpoint for registering a new user +func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + config := a.config + db := a.db.WithContext(ctx) + + if config.DisableSignup { + return unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance") + } + + params := &SignupParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + params.ConfigureDefaults() + + if err := a.validateSignupParams(ctx, params); err != nil { + return err + 
} + + var err error + flowType := getFlowFromChallenge(params.CodeChallenge) + + var user *models.User + var grantParams models.GrantParams + + grantParams.FillGrantParams(r) + + params.Aud = a.requestAud(ctx, r) + + switch params.Provider { + case "email": + if !config.External.Email.Enabled { + return badRequestError(ErrorCodeEmailProviderDisabled, "Email signups are disabled") + } + params.Email, err = a.validateEmail(params.Email) + if err != nil { + return err + } + user, err = models.IsDuplicatedEmail(db, params.Email, params.Aud, nil) + case "phone": + if !config.External.Phone.Enabled { + return badRequestError(ErrorCodePhoneProviderDisabled, "Phone signups are disabled") + } + params.Phone, err = validatePhone(params.Phone) + if err != nil { + return err + } + user, err = models.FindUserByPhoneAndAudience(db, params.Phone, params.Aud) + default: + msg := "" + if config.External.Email.Enabled && config.External.Phone.Enabled { + msg = "Sign up only available with email or phone provider" + } else if config.External.Email.Enabled { + msg = "Sign up only available with email provider" + } else if config.External.Phone.Enabled { + msg = "Sign up only available with phone provider" + } else { + msg = "Sign up with this provider not possible" + } + + return badRequestError(ErrorCodeValidationFailed, msg) + } + + if err != nil && !models.IsNotFoundError(err) { + return internalServerError("Database error finding user").WithInternalError(err) + } + + var signupUser *models.User + if user == nil { + // always call this outside of a database transaction as this method + // can be computationally hard and block due to password hashing + signupUser, err = params.ToUserModel(false /* <- isSSOUser */) + if err != nil { + return err + } + } + + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if user != nil { + if (params.Provider == "email" && user.IsConfirmed()) || (params.Provider == "phone" && user.IsPhoneConfirmed()) { + return UserExistsError + } + // do not update the user because we can't be sure of their claimed identity + } else { + user, terr = a.signupNewUser(tx, signupUser) + if terr != nil { + return terr + } + } + identity, terr := models.FindIdentityByIdAndProvider(tx, user.ID.String(), params.Provider) + if terr != nil { + if !models.IsNotFoundError(terr) { + return terr + } + identityData := structs.Map(provider.Claims{ + Subject: user.ID.String(), + Email: user.GetEmail(), + }) + for k, v := range params.Data { + if _, ok := identityData[k]; !ok { + identityData[k] = v + } + } + identity, terr = a.createNewIdentity(tx, user, params.Provider, identityData) + if terr != nil { + return terr + } + if terr := user.RemoveUnconfirmedIdentities(tx, identity); terr != nil { + return terr + } + } + user.Identities = []models.Identity{*identity} + + if params.Provider == "email" && !user.IsConfirmed() { + if config.Mailer.Autoconfirm { + if terr = models.NewAuditLogEntry(r, tx, user, models.UserSignedUpAction, "", map[string]interface{}{ + "provider": params.Provider, + }); terr != nil { + return terr + } + if terr = user.Confirm(tx); terr != nil { + return internalServerError("Database error updating user").WithInternalError(terr) + } + } else { + if terr = models.NewAuditLogEntry(r, tx, user, models.UserConfirmationRequestedAction, "", map[string]interface{}{ + "provider": params.Provider, + }); terr != nil { + return terr + } + if isPKCEFlow(flowType) { + _, terr := generateFlowState(tx, params.Provider, models.EmailSignup, params.CodeChallengeMethod, 
params.CodeChallenge, &user.ID) + if terr != nil { + return terr + } + } + if terr = a.sendConfirmation(r, tx, user, flowType); terr != nil { + return terr + } + } + } else if params.Provider == "phone" && !user.IsPhoneConfirmed() { + if config.Sms.Autoconfirm { + if terr = models.NewAuditLogEntry(r, tx, user, models.UserSignedUpAction, "", map[string]interface{}{ + "provider": params.Provider, + "channel": params.Channel, + }); terr != nil { + return terr + } + if terr = user.ConfirmPhone(tx); terr != nil { + return internalServerError("Database error updating user").WithInternalError(terr) + } + } else { + if terr = models.NewAuditLogEntry(r, tx, user, models.UserConfirmationRequestedAction, "", map[string]interface{}{ + "provider": params.Provider, + }); terr != nil { + return terr + } + if _, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, params.Channel); terr != nil { + return terr + } + } + } + + return nil + }) + + if err != nil { + if errors.Is(err, UserExistsError) { + err = db.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserRepeatedSignUpAction, "", map[string]interface{}{ + "provider": params.Provider, + }); terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + if config.Mailer.Autoconfirm || config.Sms.Autoconfirm { + return unprocessableEntityError(ErrorCodeUserAlreadyExists, "User already registered") + } + sanitizedUser, err := sanitizeUser(user, params) + if err != nil { + return err + } + return sendJSON(w, http.StatusOK, sanitizedUser) + } + return err + } + + // handles case where Mailer.Autoconfirm is true or Phone.Autoconfirm is true + if user.IsConfirmed() || user.IsPhoneConfirmed() { + var token *AccessTokenResponse + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if terr = models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", map[string]interface{}{ + "provider": params.Provider, + }); terr != nil { + return terr + } + token, terr = a.issueRefreshToken(r, tx, user, models.PasswordGrant, grantParams) + + if terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + metering.RecordLogin("password", user.ID) + return sendJSON(w, http.StatusOK, token) + } + if user.HasBeenInvited() { + // Remove sensitive fields + user.UserMetaData = map[string]interface{}{} + user.Identities = []models.Identity{} + } + return sendJSON(w, http.StatusOK, user) +} + +// sanitizeUser removes all user sensitive information from the user object +// Should be used whenever we want to prevent information about whether a user is registered or not from leaking +func sanitizeUser(u *models.User, params *SignupParams) (*models.User, error) { + now := time.Now() + + u.ID = uuid.Must(uuid.NewV4()) + + u.Role, u.EmailChange = "", "" + u.CreatedAt, u.UpdatedAt, u.ConfirmationSentAt = now, now, &now + u.LastSignInAt, u.ConfirmedAt, u.EmailChangeSentAt, u.EmailConfirmedAt, u.PhoneConfirmedAt = nil, nil, nil, nil, nil + u.Identities = make([]models.Identity, 0) + u.UserMetaData = params.Data + u.Aud = params.Aud + + // sanitize app_metadata + u.AppMetaData = map[string]interface{}{ + "provider": params.Provider, + "providers": []string{params.Provider}, + } + + // sanitize param fields + switch params.Provider { + case "email": + u.Phone = "" + case "phone": + u.Email = "" + default: + u.Phone, u.Email = "", "" + } + + return u, nil +} + +func (a *API) signupNewUser(conn *storage.Connection, user *models.User) 
(*models.User, error) { + config := a.config + + err := conn.Transaction(func(tx *storage.Connection) error { + var terr error + if terr = tx.Create(user); terr != nil { + return internalServerError("Database error saving new user").WithInternalError(terr) + } + if terr = user.SetRole(tx, config.JWT.DefaultGroupName); terr != nil { + return internalServerError("Database error updating user").WithInternalError(terr) + } + return nil + }) + if err != nil { + return nil, err + } + + // there may be triggers or generated column values in the database that will modify the + // user data as it is being inserted. thus we load the user object + // again to fetch those changes. + if err := conn.Reload(user); err != nil { + return nil, internalServerError("Database error loading user after sign-up").WithInternalError(err) + } + + return user, nil +} diff --git a/auth_v2.169.0/internal/api/signup_test.go b/auth_v2.169.0/internal/api/signup_test.go new file mode 100644 index 0000000..3f47832 --- /dev/null +++ b/auth_v2.169.0/internal/api/signup_test.go @@ -0,0 +1,153 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + mail "github.com/supabase/auth/internal/mailer" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type SignupTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestSignup(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &SignupTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *SignupTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) +} + +// TestSignup tests API /signup route +func (ts *SignupTestSuite) TestSignup() { + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "password": "test123", + "data": map[string]interface{}{ + "a": 1, + }, + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + data := models.User{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + assert.Equal(ts.T(), "test@example.com", data.GetEmail()) + assert.Equal(ts.T(), ts.Config.JWT.Aud, data.Aud) + assert.Equal(ts.T(), 1.0, data.UserMetaData["a"]) + assert.Equal(ts.T(), "email", data.AppMetaData["provider"]) + assert.Equal(ts.T(), []interface{}{"email"}, data.AppMetaData["providers"]) +} + +// TestSignupTwice checks to make sure the same email cannot be registered twice +func (ts *SignupTestSuite) TestSignupTwice() { + // Request body + var buffer bytes.Buffer + + encode := func() { + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test1@example.com", + "password": "test123", + "data": map[string]interface{}{ + "a": 1, + }, + })) + } + + encode() + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/signup", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + y := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(y, req) + 
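+	// The first request above registers test1@example.com; the user is then
+	// confirmed directly in the database so that the identical signup below goes
+	// down the duplicate-user path. Because autoconfirm is off in the test
+	// config, that path still returns 200 with a sanitized placeholder user
+	// (hence the NotEqual assertion on the ID), so callers cannot probe whether
+	// an email address is already registered.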
u, err := models.FindUserByEmailAndAudience(ts.API.db, "test1@example.com", ts.Config.JWT.Aud) + if err == nil { + require.NoError(ts.T(), u.Confirm(ts.API.db)) + } + + encode() + ts.API.handler.ServeHTTP(w, req) + + data := models.User{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + assert.NotEqual(ts.T(), u.ID, data.ID) + assert.Equal(ts.T(), "test1@example.com", data.GetEmail()) + assert.Equal(ts.T(), ts.Config.JWT.Aud, data.Aud) + assert.Equal(ts.T(), 1.0, data.UserMetaData["a"]) + assert.Equal(ts.T(), "email", data.AppMetaData["provider"]) + assert.Equal(ts.T(), []interface{}{"email"}, data.AppMetaData["providers"]) +} + +func (ts *SignupTestSuite) TestVerifySignup() { + user, err := models.NewUser("123456789", "test@example.com", "testing", ts.Config.JWT.Aud, nil) + user.ConfirmationToken = "asdf3" + now := time.Now() + user.ConfirmationSentAt = &now + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(user)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, user.ID, user.GetEmail(), user.ConfirmationToken, models.ConfirmationToken)) + + // Find test user + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // Setup request + reqUrl := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.SignupVerification, u.ConfirmationToken) + req := httptest.NewRequest(http.MethodGet, reqUrl, nil) + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + + urlVal, err := url.Parse(w.Result().Header.Get("Location")) + require.NoError(ts.T(), err) + v, err := url.ParseQuery(urlVal.Fragment) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), v.Get("access_token")) + require.NotEmpty(ts.T(), v.Get("expires_in")) + require.NotEmpty(ts.T(), v.Get("refresh_token")) +} diff --git a/auth_v2.169.0/internal/api/sms_provider/messagebird.go b/auth_v2.169.0/internal/api/sms_provider/messagebird.go new file mode 100644 index 0000000..05f7939 --- /dev/null +++ b/auth_v2.169.0/internal/api/sms_provider/messagebird.go @@ -0,0 +1,115 @@ +package sms_provider + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" +) + +const ( + defaultMessagebirdApiBase = "https://rest.messagebird.com" +) + +type MessagebirdProvider struct { + Config *conf.MessagebirdProviderConfiguration + APIPath string +} + +type MessagebirdResponseRecipients struct { + TotalSentCount int `json:"totalSentCount"` +} + +type MessagebirdResponse struct { + ID string `json:"id"` + Recipients MessagebirdResponseRecipients `json:"recipients"` +} + +type MessagebirdError struct { + Code int `json:"code"` + Description string `json:"description"` + Parameter string `json:"parameter"` +} + +type MessagebirdErrResponse struct { + Errors []MessagebirdError `json:"errors"` +} + +func (t MessagebirdErrResponse) Error() string { + return t.Errors[0].Description +} + +// Creates a SmsProvider with the Messagebird Config +func NewMessagebirdProvider(config conf.MessagebirdProviderConfiguration) (SmsProvider, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + apiPath := defaultMessagebirdApiBase + "/messages" + return &MessagebirdProvider{ + Config: &config, + APIPath: apiPath, + }, nil +} + +func (t *MessagebirdProvider) 
SendMessage(phone, message, channel, otp string) (string, error) { + switch channel { + case SMSProvider: + return t.SendSms(phone, message) + default: + return "", fmt.Errorf("channel type %q is not supported for Messagebird", channel) + } +} + +// Send an SMS containing the OTP with Messagebird's API +func (t *MessagebirdProvider) SendSms(phone string, message string) (string, error) { + body := url.Values{ + "originator": {t.Config.Originator}, + "body": {message}, + "recipients": {phone}, + "type": {"sms"}, + "datacoding": {"unicode"}, + } + + client := &http.Client{Timeout: defaultTimeout} + r, err := http.NewRequest("POST", t.APIPath, strings.NewReader(body.Encode())) + if err != nil { + return "", err + } + r.Header.Add("Content-Type", "application/x-www-form-urlencoded") + r.Header.Add("Authorization", "AccessKey "+t.Config.AccessKey) + res, err := client.Do(r) + if err != nil { + return "", err + } + + if res.StatusCode == http.StatusBadRequest || res.StatusCode == http.StatusForbidden || res.StatusCode == http.StatusUnauthorized || res.StatusCode == http.StatusUnprocessableEntity { + resp := &MessagebirdErrResponse{} + if err := json.NewDecoder(res.Body).Decode(resp); err != nil { + return "", err + } + return "", resp + } + defer utilities.SafeClose(res.Body) + + // validate sms status + resp := &MessagebirdResponse{} + derr := json.NewDecoder(res.Body).Decode(resp) + if derr != nil { + return "", derr + } + + if resp.Recipients.TotalSentCount == 0 { + return "", fmt.Errorf("messagebird error: total sent count is 0") + } + + return resp.ID, nil +} + +func (t *MessagebirdProvider) VerifyOTP(phone, code string) error { + return fmt.Errorf("VerifyOTP is not supported for Messagebird") +} diff --git a/auth_v2.169.0/internal/api/sms_provider/sms_provider.go b/auth_v2.169.0/internal/api/sms_provider/sms_provider.go new file mode 100644 index 0000000..103db4f --- /dev/null +++ b/auth_v2.169.0/internal/api/sms_provider/sms_provider.go @@ -0,0 +1,70 @@ +package sms_provider + +import ( + "fmt" + "log" + "os" + "time" + + "github.com/supabase/auth/internal/conf" +) + +// overrides the SmsProvider set to always return the mock provider +var MockProvider SmsProvider = nil + +var defaultTimeout time.Duration = time.Second * 10 + +const SMSProvider = "sms" +const WhatsappProvider = "whatsapp" + +func init() { + timeoutStr := os.Getenv("GOTRUE_INTERNAL_HTTP_TIMEOUT") + if timeoutStr != "" { + if timeout, err := time.ParseDuration(timeoutStr); err != nil { + log.Fatalf("error loading GOTRUE_INTERNAL_HTTP_TIMEOUT: %v", err.Error()) + } else if timeout != 0 { + defaultTimeout = timeout + } + } +} + +type SmsProvider interface { + SendMessage(phone, message, channel, otp string) (string, error) + VerifyOTP(phone, token string) error +} + +func GetSmsProvider(config conf.GlobalConfiguration) (SmsProvider, error) { + if MockProvider != nil { + return MockProvider, nil + } + + switch name := config.Sms.Provider; name { + case "twilio": + return NewTwilioProvider(config.Sms.Twilio) + case "messagebird": + return NewMessagebirdProvider(config.Sms.Messagebird) + case "textlocal": + return NewTextlocalProvider(config.Sms.Textlocal) + case "vonage": + return NewVonageProvider(config.Sms.Vonage) + case "twilio_verify": + return NewTwilioVerifyProvider(config.Sms.TwilioVerify) + default: + return nil, fmt.Errorf("sms Provider %s could not be found", name) + } +} + +func IsValidMessageChannel(channel string, config *conf.GlobalConfiguration) bool { + if config.Hook.SendSMS.Enabled { + // channel doesn't 
matter if SMS hook is enabled + return true + } + switch channel { + case SMSProvider: + return true + case WhatsappProvider: + return config.Sms.Provider == "twilio" || config.Sms.Provider == "twilio_verify" + default: + return false + } +} diff --git a/auth_v2.169.0/internal/api/sms_provider/sms_provider_test.go b/auth_v2.169.0/internal/api/sms_provider/sms_provider_test.go new file mode 100644 index 0000000..e5b5216 --- /dev/null +++ b/auth_v2.169.0/internal/api/sms_provider/sms_provider_test.go @@ -0,0 +1,287 @@ +package sms_provider + +import ( + "encoding/base64" + "fmt" + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "gopkg.in/h2non/gock.v1" +) + +var handleApiRequest func(*http.Request) (*http.Response, error) + +type SmsProviderTestSuite struct { + suite.Suite + Config *conf.GlobalConfiguration +} + +type MockHttpClient struct { + mock.Mock +} + +func (m *MockHttpClient) Do(req *http.Request) (*http.Response, error) { + return handleApiRequest(req) +} + +func TestSmsProvider(t *testing.T) { + ts := &SmsProviderTestSuite{ + Config: &conf.GlobalConfiguration{ + Sms: conf.SmsProviderConfiguration{ + Twilio: conf.TwilioProviderConfiguration{ + AccountSid: "test_account_sid", + AuthToken: "test_auth_token", + MessageServiceSid: "test_message_service_id", + }, + TwilioVerify: conf.TwilioVerifyProviderConfiguration{ + AccountSid: "test_account_sid", + AuthToken: "test_auth_token", + MessageServiceSid: "test_message_service_id", + }, + Messagebird: conf.MessagebirdProviderConfiguration{ + AccessKey: "test_access_key", + Originator: "test_originator", + }, + Vonage: conf.VonageProviderConfiguration{ + ApiKey: "test_api_key", + ApiSecret: "test_api_secret", + From: "test_from", + }, + Textlocal: conf.TextlocalProviderConfiguration{ + ApiKey: "test_api_key", + Sender: "test_sender", + }, + }, + }, + } + suite.Run(t, ts) +} + +func (ts *SmsProviderTestSuite) TestTwilioSendSms() { + defer gock.Off() + provider, err := NewTwilioProvider(ts.Config.Sms.Twilio) + require.NoError(ts.T(), err) + + twilioProvider, ok := provider.(*TwilioProvider) + require.Equal(ts.T(), true, ok) + + phone := "123456789" + message := "This is the sms code: 123456" + + body := url.Values{ + "To": {"+" + phone}, + "Channel": {"sms"}, + "From": {twilioProvider.Config.MessageServiceSid}, + "Body": {message}, + } + + cases := []struct { + Desc string + TwilioResponse *gock.Response + ExpectedError error + OTP string + }{ + { + Desc: "Successfully sent sms", + TwilioResponse: gock.New(twilioProvider.APIPath).Post(""). + MatchHeader("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(twilioProvider.Config.AccountSid+":"+twilioProvider.Config.AuthToken))). + MatchType("url").BodyString(body.Encode()). + Reply(200).JSON(SmsStatus{ + To: "+" + phone, + From: twilioProvider.Config.MessageServiceSid, + Status: "sent", + Body: message, + MessageSID: "abcdef", + }), + OTP: "123456", + ExpectedError: nil, + }, + { + Desc: "Sms status is failed / undelivered", + TwilioResponse: gock.New(twilioProvider.APIPath).Post(""). + MatchHeader("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(twilioProvider.Config.AccountSid+":"+twilioProvider.Config.AuthToken))). + MatchType("url").BodyString(body.Encode()). 
+ Reply(200).JSON(SmsStatus{ + ErrorMessage: "failed to send sms", + ErrorCode: "401", + Status: "failed", + MessageSID: "abcdef", + }), + ExpectedError: fmt.Errorf("twilio error: %v %v for message %v", "failed to send sms", "401", "abcdef"), + }, + { + Desc: "Non-2xx status code returned", + TwilioResponse: gock.New(twilioProvider.APIPath).Post(""). + MatchHeader("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(twilioProvider.Config.AccountSid+":"+twilioProvider.Config.AuthToken))). + MatchType("url").BodyString(body.Encode()). + Reply(500).JSON(twilioErrResponse{ + Code: 500, + Message: "Internal server error", + MoreInfo: "error", + Status: 500, + }), + OTP: "123456", + ExpectedError: &twilioErrResponse{ + Code: 500, + Message: "Internal server error", + MoreInfo: "error", + Status: 500, + }, + }, + } + + for _, c := range cases { + ts.Run(c.Desc, func() { + _, err = twilioProvider.SendSms(phone, message, SMSProvider, c.OTP) + require.Equal(ts.T(), c.ExpectedError, err) + }) + } +} + +func (ts *SmsProviderTestSuite) TestMessagebirdSendSms() { + defer gock.Off() + provider, err := NewMessagebirdProvider(ts.Config.Sms.Messagebird) + require.NoError(ts.T(), err) + + messagebirdProvider, ok := provider.(*MessagebirdProvider) + require.Equal(ts.T(), true, ok) + + phone := "123456789" + message := "This is the sms code: 123456" + body := url.Values{ + "originator": {messagebirdProvider.Config.Originator}, + "body": {message}, + "recipients": {phone}, + "type": {"sms"}, + "datacoding": {"unicode"}, + } + gock.New(messagebirdProvider.APIPath).Post("").MatchHeader("Authorization", "AccessKey "+messagebirdProvider.Config.AccessKey).MatchType("url").BodyString(body.Encode()).Reply(200).JSON(MessagebirdResponse{ + Recipients: MessagebirdResponseRecipients{ + TotalSentCount: 1, + }, + }) + + _, err = messagebirdProvider.SendSms(phone, message) + require.NoError(ts.T(), err) +} + +func (ts *SmsProviderTestSuite) TestVonageSendSms() { + defer gock.Off() + provider, err := NewVonageProvider(ts.Config.Sms.Vonage) + require.NoError(ts.T(), err) + + vonageProvider, ok := provider.(*VonageProvider) + require.Equal(ts.T(), true, ok) + + phone := "123456789" + message := "This is the sms code: 123456" + + body := url.Values{ + "from": {vonageProvider.Config.From}, + "to": {phone}, + "text": {message}, + "api_key": {vonageProvider.Config.ApiKey}, + "api_secret": {vonageProvider.Config.ApiSecret}, + } + + gock.New(vonageProvider.APIPath).Post("").MatchType("url").BodyString(body.Encode()).Reply(200).JSON(VonageResponse{ + Messages: []VonageResponseMessage{ + {Status: "0"}, + }, + }) + + _, err = vonageProvider.SendSms(phone, message) + require.NoError(ts.T(), err) +} + +func (ts *SmsProviderTestSuite) TestTextLocalSendSms() { + defer gock.Off() + provider, err := NewTextlocalProvider(ts.Config.Sms.Textlocal) + require.NoError(ts.T(), err) + + textlocalProvider, ok := provider.(*TextlocalProvider) + require.Equal(ts.T(), true, ok) + + phone := "123456789" + message := "This is the sms code: 123456" + body := url.Values{ + "sender": {textlocalProvider.Config.Sender}, + "apikey": {textlocalProvider.Config.ApiKey}, + "message": {message}, + "numbers": {phone}, + } + + gock.New(textlocalProvider.APIPath).Post("").MatchType("url").BodyString(body.Encode()).Reply(200).JSON(TextlocalResponse{ + Status: "success", + Errors: []TextlocalError{}, + }) + + _, err = textlocalProvider.SendSms(phone, message) + require.NoError(ts.T(), err) +} +func (ts *SmsProviderTestSuite) TestTwilioVerifySendSms() { + 
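The suite above stubs each provider's HTTP API with gock. Because the providers expose APIPath as a plain struct field, the same assertions can also be made without an interception library by pointing the provider at a local httptest server. The sketch below shows the general shape with a hand-rolled fake endpoint; the handler and its response body are stand-ins, not Twilio's actual contract.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
)

func main() {
	// Fake Twilio-style endpoint that always reports the message as "sent".
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_ = r.ParseForm()
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]string{
			"sid":    "abcdef",
			"status": "sent",
			"to":     r.PostForm.Get("To"),
		})
	}))
	defer srv.Close()

	// In a real test, twilioProvider.APIPath = srv.URL would go here; this
	// sketch just posts the form directly.
	form := url.Values{
		"To":   {"+123456789"},
		"Body": {"This is the sms code: 123456"},
	}
	res, err := http.Post(srv.URL, "application/x-www-form-urlencoded", strings.NewReader(form.Encode()))
	if err != nil {
		panic(err)
	}
	defer res.Body.Close()

	var out struct {
		SID    string `json:"sid"`
		Status string `json:"status"`
		To     string `json:"to"`
	}
	_ = json.NewDecoder(res.Body).Decode(&out)
	fmt.Println(out.SID, out.Status, out.To) // abcdef sent +123456789
}
```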
defer gock.Off() + provider, err := NewTwilioVerifyProvider(ts.Config.Sms.TwilioVerify) + require.NoError(ts.T(), err) + + twilioVerifyProvider, ok := provider.(*TwilioVerifyProvider) + require.Equal(ts.T(), true, ok) + + phone := "123456789" + message := "This is the sms code: 123456" + + body := url.Values{ + "To": {"+" + phone}, + "Channel": {"sms"}, + } + + cases := []struct { + Desc string + TwilioResponse *gock.Response + ExpectedError error + }{ + { + Desc: "Successfully sent sms", + TwilioResponse: gock.New(twilioVerifyProvider.APIPath).Post(""). + MatchHeader("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(twilioVerifyProvider.Config.AccountSid+":"+twilioVerifyProvider.Config.AuthToken))). + MatchType("url").BodyString(body.Encode()). + Reply(200).JSON(SmsStatus{ + To: "+" + phone, + From: twilioVerifyProvider.Config.MessageServiceSid, + Status: "sent", + Body: message, + }), + ExpectedError: nil, + }, + { + Desc: "Non-2xx status code returned", + TwilioResponse: gock.New(twilioVerifyProvider.APIPath).Post(""). + MatchHeader("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(twilioVerifyProvider.Config.AccountSid+":"+twilioVerifyProvider.Config.AuthToken))). + MatchType("url").BodyString(body.Encode()). + Reply(500).JSON(twilioErrResponse{ + Code: 500, + Message: "Internal server error", + MoreInfo: "error", + Status: 500, + }), + ExpectedError: &twilioErrResponse{ + Code: 500, + Message: "Internal server error", + MoreInfo: "error", + Status: 500, + }, + }, + } + + for _, c := range cases { + ts.Run(c.Desc, func() { + _, err = twilioVerifyProvider.SendSms(phone, message, SMSProvider) + require.Equal(ts.T(), c.ExpectedError, err) + }) + } +} diff --git a/auth_v2.169.0/internal/api/sms_provider/textlocal.go b/auth_v2.169.0/internal/api/sms_provider/textlocal.go new file mode 100644 index 0000000..ef07a6f --- /dev/null +++ b/auth_v2.169.0/internal/api/sms_provider/textlocal.go @@ -0,0 +1,107 @@ +package sms_provider + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" +) + +const ( + defaultTextLocalApiBase = "https://api.textlocal.in" + textLocalTemplateErrorCode = 80 +) + +type TextlocalProvider struct { + Config *conf.TextlocalProviderConfiguration + APIPath string +} + +type TextlocalError struct { + Code int `json:"code"` + Message string `json:"message"` +} + +type TextlocalResponse struct { + Status string `json:"status"` + Errors []TextlocalError `json:"errors"` + Messages []TextlocalMessage `json:"messages"` +} + +type TextlocalMessage struct { + MessageID string `json:"id"` +} + +// Creates a SmsProvider with the Textlocal Config +func NewTextlocalProvider(config conf.TextlocalProviderConfiguration) (SmsProvider, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + apiPath := defaultTextLocalApiBase + "/send" + return &TextlocalProvider{ + Config: &config, + APIPath: apiPath, + }, nil +} + +func (t *TextlocalProvider) SendMessage(phone, message, channel, otp string) (string, error) { + switch channel { + case SMSProvider: + return t.SendSms(phone, message) + default: + return "", fmt.Errorf("channel type %q is not supported for TextLocal", channel) + } +} + +// Send an SMS containing the OTP with Textlocal's API +func (t *TextlocalProvider) SendSms(phone string, message string) (string, error) { + body := url.Values{ + "sender": {t.Config.Sender}, + "apikey": {t.Config.ApiKey}, + "message": 
{message}, + "numbers": {phone}, + } + + client := &http.Client{Timeout: defaultTimeout} + r, err := http.NewRequest("POST", t.APIPath, strings.NewReader(body.Encode())) + if err != nil { + return "", err + } + + r.Header.Add("Content-Type", "application/x-www-form-urlencoded") + res, err := client.Do(r) + if err != nil { + return "", err + } + defer utilities.SafeClose(res.Body) + + resp := &TextlocalResponse{} + derr := json.NewDecoder(res.Body).Decode(resp) + if derr != nil { + return "", derr + } + + messageID := "" + + if resp.Status != "success" { + if len(resp.Messages) > 0 { + messageID = resp.Messages[0].MessageID + } + + if len(resp.Errors) > 0 && resp.Errors[0].Code == textLocalTemplateErrorCode { + return messageID, fmt.Errorf("textlocal error: %v (code: %v) template message: %s", resp.Errors[0].Message, resp.Errors[0].Code, message) + } + + return messageID, fmt.Errorf("textlocal error: %v (code: %v) message %s", resp.Errors[0].Message, resp.Errors[0].Code, messageID) + } + + return messageID, nil +} +func (t *TextlocalProvider) VerifyOTP(phone, code string) error { + return fmt.Errorf("VerifyOTP is not supported for Textlocal") +} diff --git a/auth_v2.169.0/internal/api/sms_provider/twilio.go b/auth_v2.169.0/internal/api/sms_provider/twilio.go new file mode 100644 index 0000000..3536c2f --- /dev/null +++ b/auth_v2.169.0/internal/api/sms_provider/twilio.go @@ -0,0 +1,141 @@ +package sms_provider + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "regexp" + "strings" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" +) + +const ( + defaultTwilioApiBase = "https://api.twilio.com" + apiVersion = "2010-04-01" +) + +type TwilioProvider struct { + Config *conf.TwilioProviderConfiguration + APIPath string +} + +var isPhoneNumber = regexp.MustCompile("^[1-9][0-9]{1,14}$") + +// formatPhoneNumber removes "+" and whitespaces in a phone number +func formatPhoneNumber(phone string) string { + return strings.ReplaceAll(strings.TrimPrefix(phone, "+"), " ", "") +} + +type SmsStatus struct { + To string `json:"to"` + From string `json:"from"` + MessageSID string `json:"sid"` + Status string `json:"status"` + ErrorCode string `json:"error_code"` + ErrorMessage string `json:"error_message"` + Body string `json:"body"` +} + +type twilioErrResponse struct { + Code int `json:"code"` + Message string `json:"message"` + MoreInfo string `json:"more_info"` + Status int `json:"status"` +} + +func (t twilioErrResponse) Error() string { + return fmt.Sprintf("%s More information: %s", t.Message, t.MoreInfo) +} + +// Creates a SmsProvider with the Twilio Config +func NewTwilioProvider(config conf.TwilioProviderConfiguration) (SmsProvider, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + apiPath := defaultTwilioApiBase + "/" + apiVersion + "/" + "Accounts" + "/" + config.AccountSid + "/Messages.json" + return &TwilioProvider{ + Config: &config, + APIPath: apiPath, + }, nil +} + +func (t *TwilioProvider) SendMessage(phone, message, channel, otp string) (string, error) { + switch channel { + case SMSProvider, WhatsappProvider: + return t.SendSms(phone, message, channel, otp) + default: + return "", fmt.Errorf("channel type %q is not supported for Twilio", channel) + } +} + +// Send an SMS containing the OTP with Twilio's API +func (t *TwilioProvider) SendSms(phone, message, channel, otp string) (string, error) { + sender := t.Config.MessageServiceSid + receiver := "+" + phone + body := url.Values{ + "To": {receiver}, // 
twilio api requires "+" extension to be included + "Channel": {channel}, + "From": {sender}, + "Body": {message}, + } + if channel == WhatsappProvider { + receiver = channel + ":" + receiver + if isPhoneNumber.MatchString(formatPhoneNumber(sender)) { + sender = channel + ":" + sender + } + + // Programmable Messaging (WhatsApp) takes in different set of inputs + body = url.Values{ + "To": {receiver}, // twilio api requires "+" extension to be included + "Channel": {channel}, + "From": {sender}, + } + // For backward compatibility with old API. + if t.Config.ContentSid != "" { + // Used to substitute OTP. See https://www.twilio.com/docs/content/whatsappauthentication for more details + contentVariables := fmt.Sprintf(`{"1": "%s"}`, otp) + body.Set("ContentSid", t.Config.ContentSid) + body.Set("ContentVariables", contentVariables) + } else { + body.Set("Body", message) + } + } + client := &http.Client{Timeout: defaultTimeout} + r, err := http.NewRequest("POST", t.APIPath, strings.NewReader(body.Encode())) + if err != nil { + return "", err + } + r.Header.Add("Content-Type", "application/x-www-form-urlencoded") + r.SetBasicAuth(t.Config.AccountSid, t.Config.AuthToken) + res, err := client.Do(r) + if err != nil { + return "", err + } + defer utilities.SafeClose(res.Body) + if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusCreated { + resp := &twilioErrResponse{} + if err := json.NewDecoder(res.Body).Decode(resp); err != nil { + return "", err + } + return "", resp + } + // validate sms status + resp := &SmsStatus{} + derr := json.NewDecoder(res.Body).Decode(resp) + if derr != nil { + return "", derr + } + + if resp.Status == "failed" || resp.Status == "undelivered" { + return resp.MessageSID, fmt.Errorf("twilio error: %v %v for message %s", resp.ErrorMessage, resp.ErrorCode, resp.MessageSID) + } + + return resp.MessageSID, nil +} +func (t *TwilioProvider) VerifyOTP(phone, code string) error { + return fmt.Errorf("VerifyOTP is not supported for Twilio") +} diff --git a/auth_v2.169.0/internal/api/sms_provider/twilio_verify.go b/auth_v2.169.0/internal/api/sms_provider/twilio_verify.go new file mode 100644 index 0000000..8ec5463 --- /dev/null +++ b/auth_v2.169.0/internal/api/sms_provider/twilio_verify.go @@ -0,0 +1,139 @@ +package sms_provider + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" +) + +const ( + verifyServiceApiBase = "https://verify.twilio.com/v2/Services/" +) + +type TwilioVerifyProvider struct { + Config *conf.TwilioVerifyProviderConfiguration + APIPath string +} + +type VerificationResponse struct { + To string `json:"to"` + Status string `json:"status"` + Channel string `json:"channel"` + Valid bool `json:"valid"` + VerificationSID string `json:"sid"` + ErrorCode string `json:"error_code"` + ErrorMessage string `json:"error_message"` +} + +// See: https://www.twilio.com/docs/verify/api/verification-check +type VerificationCheckResponse struct { + To string `json:"to"` + Status string `json:"status"` + Channel string `json:"channel"` + Valid bool `json:"valid"` + ErrorCode string `json:"error_code"` + ErrorMessage string `json:"error_message"` +} + +// Creates a SmsProvider with the Twilio Config +func NewTwilioVerifyProvider(config conf.TwilioVerifyProviderConfiguration) (SmsProvider, error) { + if err := config.Validate(); err != nil { + return nil, err + } + apiPath := verifyServiceApiBase + config.MessageServiceSid + "/Verifications" + + return 
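The WhatsApp branch in TwilioProvider.SendSms swaps the form body: the recipient (and a phone-number sender) gains a "whatsapp:" prefix, and when ContentSid is configured the OTP travels in ContentVariables instead of Body. A condensed, hypothetical sketch of that body construction follows; it omits the phone-number check on the sender for brevity.

```go
package main

import (
	"fmt"
	"net/url"
)

// buildBody is a simplified stand-in for the body construction in
// TwilioProvider.SendSms above. The isPhoneNumber check on the sender is
// omitted here for brevity.
func buildBody(phone, message, channel, otp, sender, contentSid string) url.Values {
	receiver := "+" + phone
	body := url.Values{
		"To":      {receiver},
		"Channel": {channel},
		"From":    {sender},
		"Body":    {message},
	}
	if channel == "whatsapp" {
		body = url.Values{
			"To":      {channel + ":" + receiver},
			"Channel": {channel},
			"From":    {channel + ":" + sender},
		}
		if contentSid != "" {
			// The OTP is substituted into the content template.
			body.Set("ContentSid", contentSid)
			body.Set("ContentVariables", fmt.Sprintf(`{"1": "%s"}`, otp))
		} else {
			body.Set("Body", message)
		}
	}
	return body
}

func main() {
	fmt.Println(buildBody("123456789", "code: 123456", "sms", "123456", "+15550100", "").Encode())
	fmt.Println(buildBody("123456789", "code: 123456", "whatsapp", "123456", "+15550100", "HX123").Encode())
}
```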
&TwilioVerifyProvider{ + Config: &config, + APIPath: apiPath, + }, nil +} + +func (t *TwilioVerifyProvider) SendMessage(phone, message, channel, otp string) (string, error) { + switch channel { + case SMSProvider, WhatsappProvider: + return t.SendSms(phone, message, channel) + default: + return "", fmt.Errorf("channel type %q is not supported for Twilio", channel) + } +} + +// Send an SMS containing the OTP with Twilio's API +func (t *TwilioVerifyProvider) SendSms(phone, message, channel string) (string, error) { + // Unlike Programmable Messaging, Verify does not require a prefix for channel + receiver := "+" + phone + body := url.Values{ + "To": {receiver}, + "Channel": {channel}, + } + client := &http.Client{Timeout: defaultTimeout} + r, err := http.NewRequest("POST", t.APIPath, strings.NewReader(body.Encode())) + if err != nil { + return "", err + } + r.Header.Add("Content-Type", "application/x-www-form-urlencoded") + r.SetBasicAuth(t.Config.AccountSid, t.Config.AuthToken) + res, err := client.Do(r) + if err != nil { + return "", err + } + defer utilities.SafeClose(res.Body) + if !(res.StatusCode == http.StatusOK || res.StatusCode == http.StatusCreated) { + resp := &twilioErrResponse{} + if err := json.NewDecoder(res.Body).Decode(resp); err != nil { + return "", err + } + return "", resp + } + + resp := &VerificationResponse{} + derr := json.NewDecoder(res.Body).Decode(resp) + if derr != nil { + return "", derr + } + return resp.VerificationSID, nil +} + +func (t *TwilioVerifyProvider) VerifyOTP(phone, code string) error { + verifyPath := verifyServiceApiBase + t.Config.MessageServiceSid + "/VerificationCheck" + receiver := "+" + phone + + body := url.Values{ + "To": {receiver}, // twilio api requires "+" extension to be included + "Code": {code}, + } + client := &http.Client{Timeout: defaultTimeout} + r, err := http.NewRequest("POST", verifyPath, strings.NewReader(body.Encode())) + if err != nil { + return err + } + r.Header.Add("Content-Type", "application/x-www-form-urlencoded") + r.SetBasicAuth(t.Config.AccountSid, t.Config.AuthToken) + res, err := client.Do(r) + if err != nil { + return err + } + defer utilities.SafeClose(res.Body) + if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusCreated { + resp := &twilioErrResponse{} + if err := json.NewDecoder(res.Body).Decode(resp); err != nil { + return err + } + return resp + } + resp := &VerificationCheckResponse{} + derr := json.NewDecoder(res.Body).Decode(resp) + if derr != nil { + return derr + } + + if resp.Status != "approved" || !resp.Valid { + return fmt.Errorf("twilio verification error: %v %v", resp.ErrorMessage, resp.Status) + } + + return nil +} diff --git a/auth_v2.169.0/internal/api/sms_provider/vonage.go b/auth_v2.169.0/internal/api/sms_provider/vonage.go new file mode 100644 index 0000000..4b9fd5b --- /dev/null +++ b/auth_v2.169.0/internal/api/sms_provider/vonage.go @@ -0,0 +1,105 @@ +package sms_provider + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" + "golang.org/x/exp/utf8string" +) + +const ( + defaultVonageApiBase = "https://rest.nexmo.com" +) + +type VonageProvider struct { + Config *conf.VonageProviderConfiguration + APIPath string +} + +type VonageResponseMessage struct { + MessageID string `json:"message-id"` + Status string `json:"status"` + ErrorText string `json:"error-text"` +} + +type VonageResponse struct { + Messages []VonageResponseMessage `json:"messages"` 
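TwilioVerifyProvider.VerifyOTP, shown above, posts the code to the VerificationCheck endpoint and accepts it only when the response reports status "approved" and valid true. The sketch below isolates just that acceptance rule; the struct and helper names are hypothetical.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// verificationCheck mirrors the two fields of VerificationCheckResponse that
// decide the outcome.
type verificationCheck struct {
	Status string `json:"status"`
	Valid  bool   `json:"valid"`
}

// approved applies the same rule as VerifyOTP above: only "approved" and
// valid==true count as success.
func approved(body []byte) (bool, error) {
	var v verificationCheck
	if err := json.Unmarshal(body, &v); err != nil {
		return false, err
	}
	return v.Status == "approved" && v.Valid, nil
}

func main() {
	ok, _ := approved([]byte(`{"status":"approved","valid":true}`))
	fmt.Println(ok) // true
	ok, _ = approved([]byte(`{"status":"pending","valid":false}`))
	fmt.Println(ok) // false
}
```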
+} + +// Creates a SmsProvider with the Vonage Config +func NewVonageProvider(config conf.VonageProviderConfiguration) (SmsProvider, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + apiPath := defaultVonageApiBase + "/sms/json" + return &VonageProvider{ + Config: &config, + APIPath: apiPath, + }, nil +} + +func (t *VonageProvider) SendMessage(phone, message, channel, otp string) (string, error) { + switch channel { + case SMSProvider: + return t.SendSms(phone, message) + default: + return "", fmt.Errorf("channel type %q is not supported for Vonage", channel) + } +} + +// Send an SMS containing the OTP with Vonage's API +func (t *VonageProvider) SendSms(phone string, message string) (string, error) { + body := url.Values{ + "from": {t.Config.From}, + "to": {phone}, + "text": {message}, + "api_key": {t.Config.ApiKey}, + "api_secret": {t.Config.ApiSecret}, + } + + isMessageContainUnicode := !utf8string.NewString(message).IsASCII() + if isMessageContainUnicode { + body.Set("type", "unicode") + } + + client := &http.Client{Timeout: defaultTimeout} + r, err := http.NewRequest("POST", t.APIPath, strings.NewReader(body.Encode())) + if err != nil { + return "", err + } + + r.Header.Add("Content-Type", "application/x-www-form-urlencoded") + res, err := client.Do(r) + if err != nil { + return "", err + } + defer utilities.SafeClose(res.Body) + + resp := &VonageResponse{} + derr := json.NewDecoder(res.Body).Decode(resp) + if derr != nil { + return "", derr + } + + if len(resp.Messages) <= 0 { + return "", errors.New("vonage error: Internal Error") + } + + // A status of zero indicates success; a non-zero value means something went wrong. + if resp.Messages[0].Status != "0" { + return resp.Messages[0].MessageID, fmt.Errorf("vonage error: %v (status: %v) for message %s", resp.Messages[0].ErrorText, resp.Messages[0].Status, resp.Messages[0].MessageID) + } + + return resp.Messages[0].MessageID, nil +} + +func (t *VonageProvider) VerifyOTP(phone, code string) error { + return fmt.Errorf("VerifyOTP is not supported for Vonage") +} diff --git a/auth_v2.169.0/internal/api/sorting.go b/auth_v2.169.0/internal/api/sorting.go new file mode 100644 index 0000000..f951d95 --- /dev/null +++ b/auth_v2.169.0/internal/api/sorting.go @@ -0,0 +1,41 @@ +package api + +import ( + "fmt" + "net/http" + "strings" + + "github.com/supabase/auth/internal/models" +) + +func sort(r *http.Request, allowedFields map[string]bool, defaultSort []models.SortField) (*models.SortParams, error) { + sortParams := &models.SortParams{ + Fields: defaultSort, + } + urlParams := r.URL.Query() + if values, exists := urlParams["sort"]; exists && len(values) > 0 { + sortParams.Fields = []models.SortField{} + for _, value := range values { + parts := strings.SplitN(value, " ", 2) + field := parts[0] + if _, ok := allowedFields[field]; !ok { + return nil, fmt.Errorf("bad field for sort '%v'", field) + } + + dir := models.Descending + if len(parts) > 1 { + switch strings.ToUpper(parts[1]) { + case string(models.Ascending): + dir = models.Ascending + case string(models.Descending): + dir = models.Descending + default: + return nil, fmt.Errorf("bad direction for sort '%v', only 'asc' and 'desc' allowed", parts[1]) + } + } + sortParams.Fields = append(sortParams.Fields, models.SortField{Name: field, Dir: dir}) + } + } + + return sortParams, nil +} diff --git a/auth_v2.169.0/internal/api/sso.go b/auth_v2.169.0/internal/api/sso.go new file mode 100644 index 0000000..1003407 --- /dev/null +++ b/auth_v2.169.0/internal/api/sso.go 
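The sort helper in sorting.go above parses repeated ?sort= parameters of the form "<field>" or "<field> <asc|desc>", rejects fields outside the allow-list, and defaults the direction to descending. A standalone sketch of the same grammar is shown below; unlike the original it does not error on an unrecognised direction string.

```go
package main

import (
	"fmt"
	"net/url"
	"strings"
)

type sortField struct {
	Name string
	Dir  string
}

// parseSort is a simplified version of sort() above: unknown fields are
// rejected and the direction defaults to descending.
func parseSort(query url.Values, allowed map[string]bool) ([]sortField, error) {
	var fields []sortField
	for _, value := range query["sort"] {
		parts := strings.SplitN(value, " ", 2)
		if !allowed[parts[0]] {
			return nil, fmt.Errorf("bad field for sort '%v'", parts[0])
		}
		dir := "DESC"
		if len(parts) > 1 && strings.EqualFold(parts[1], "asc") {
			dir = "ASC"
		}
		fields = append(fields, sortField{Name: parts[0], Dir: dir})
	}
	return fields, nil
}

func main() {
	q, _ := url.ParseQuery("sort=created_at+desc&sort=email")
	fields, err := parseSort(q, map[string]bool{"created_at": true, "email": true})
	fmt.Println(fields, err) // [{created_at DESC} {email DESC}] <nil>
}
```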
@@ -0,0 +1,147 @@ +package api + +import ( + "net/http" + + "github.com/crewjam/saml" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +type SingleSignOnParams struct { + ProviderID uuid.UUID `json:"provider_id"` + Domain string `json:"domain"` + RedirectTo string `json:"redirect_to"` + SkipHTTPRedirect *bool `json:"skip_http_redirect"` + CodeChallenge string `json:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method"` +} + +type SingleSignOnResponse struct { + URL string `json:"url"` +} + +func (p *SingleSignOnParams) validate() (bool, error) { + hasProviderID := p.ProviderID != uuid.Nil + hasDomain := p.Domain != "" + + if hasProviderID && hasDomain { + return hasProviderID, badRequestError(ErrorCodeValidationFailed, "Only one of provider_id or domain supported") + } else if !hasProviderID && !hasDomain { + return hasProviderID, badRequestError(ErrorCodeValidationFailed, "A provider_id or domain needs to be provided") + } + + return hasProviderID, nil +} + +// SingleSignOn handles the single-sign-on flow for a provided SSO domain or provider. +func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + params := &SingleSignOnParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + var err error + hasProviderID := false + + if hasProviderID, err = params.validate(); err != nil { + return err + } + codeChallengeMethod := params.CodeChallengeMethod + codeChallenge := params.CodeChallenge + + if err := validatePKCEParams(codeChallengeMethod, codeChallenge); err != nil { + return err + } + flowType := getFlowFromChallenge(params.CodeChallenge) + var flowStateID *uuid.UUID + flowStateID = nil + if isPKCEFlow(flowType) { + flowState, err := generateFlowState(db, models.SSOSAML.String(), models.SSOSAML, codeChallengeMethod, codeChallenge, nil) + if err != nil { + return err + } + flowStateID = &flowState.ID + } + + var ssoProvider *models.SSOProvider + + if hasProviderID { + ssoProvider, err = models.FindSSOProviderByID(db, params.ProviderID) + if models.IsNotFoundError(err) { + return notFoundError(ErrorCodeSSOProviderNotFound, "No such SSO provider") + } else if err != nil { + return internalServerError("Unable to find SSO provider by ID").WithInternalError(err) + } + } else { + ssoProvider, err = models.FindSSOProviderByDomain(db, params.Domain) + if models.IsNotFoundError(err) { + return notFoundError(ErrorCodeSSOProviderNotFound, "No SSO provider assigned for this domain") + } else if err != nil { + return internalServerError("Unable to find SSO provider by domain").WithInternalError(err) + } + } + + entityDescriptor, err := ssoProvider.SAMLProvider.EntityDescriptor() + if err != nil { + return internalServerError("Error parsing SAML Metadata for SAML provider").WithInternalError(err) + } + + serviceProvider := a.getSAMLServiceProvider(entityDescriptor, false /* <- idpInitiated */) + + authnRequest, err := serviceProvider.MakeAuthenticationRequest( + serviceProvider.GetSSOBindingLocation(saml.HTTPRedirectBinding), + saml.HTTPRedirectBinding, + saml.HTTPPostBinding, + ) + if err != nil { + return internalServerError("Error creating SAML Authentication Request").WithInternalError(err) + } + + // Some IdPs do not support the use of the `persistent` NameID format, + // and require a different format to be sent to work. 
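From a client's perspective, the SingleSignOn handler below accepts either provider_id or domain (never both), optionally a PKCE code_challenge, and either issues a 303 redirect to the IdP or, when skip_http_redirect is set, returns the redirect URL as JSON. A hedged sketch of such a call is shown here; the base URL and port are placeholders, and the domain must already be mapped to a provider.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	payload, _ := json.Marshal(map[string]interface{}{
		"domain":             "example.com",
		"skip_http_redirect": true,
		// For the PKCE flow, code_challenge and code_challenge_method ("s256")
		// can be added as well, as exercised in the tests further down.
	})

	res, err := http.Post("http://localhost:9999/sso", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer res.Body.Close()

	var out struct {
		URL string `json:"url"`
	}
	_ = json.NewDecoder(res.Body).Decode(&out)
	fmt.Println(res.StatusCode, out.URL) // 200 and the IdP redirect URL on success
}
```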
+ if ssoProvider.SAMLProvider.NameIDFormat != nil { + authnRequest.NameIDPolicy.Format = ssoProvider.SAMLProvider.NameIDFormat + } + + relayState := models.SAMLRelayState{ + SSOProviderID: ssoProvider.ID, + RequestID: authnRequest.ID, + RedirectTo: params.RedirectTo, + FlowStateID: flowStateID, + } + + if err := db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Create(&relayState); terr != nil { + return internalServerError("Error creating SAML relay state from sign up").WithInternalError(err) + } + + return nil + }); err != nil { + return err + } + + ssoRedirectURL, err := authnRequest.Redirect(relayState.ID.String(), serviceProvider) + if err != nil { + return internalServerError("Error creating SAML authentication request redirect URL").WithInternalError(err) + } + + skipHTTPRedirect := false + + if params.SkipHTTPRedirect != nil { + skipHTTPRedirect = *params.SkipHTTPRedirect + } + + if skipHTTPRedirect { + return sendJSON(w, http.StatusOK, SingleSignOnResponse{ + URL: ssoRedirectURL.String(), + }) + } + + http.Redirect(w, r, ssoRedirectURL.String(), http.StatusSeeOther) + return nil +} diff --git a/auth_v2.169.0/internal/api/sso_test.go b/auth_v2.169.0/internal/api/sso_test.go new file mode 100644 index 0000000..bae1beb --- /dev/null +++ b/auth_v2.169.0/internal/api/sso_test.go @@ -0,0 +1,752 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + jwt "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +const dateInPast = "2001-02-03T04:05:06.789" +const dateInFarFuture = "2999-02-03T04:05:06.789" +const oneHour = "PT1H" + +type SSOTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration + AdminJWT string +} + +func TestSSO(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &SSOTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + if config.SAML.Enabled { + suite.Run(t, ts) + } +} + +func (ts *SSOTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + claims := &AccessTokenClaims{ + Role: "supabase_admin", + } + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(ts.Config.JWT.Secret)) + require.NoError(ts.T(), err, "Error generating admin jwt") + + ts.AdminJWT = token +} + +func (ts *SSOTestSuite) TestNonAdminJWT() { + // TODO +} + +func (ts *SSOTestSuite) TestAdminListEmptySSOProviders() { + req := httptest.NewRequest(http.MethodGet, "http://localhost/admin/sso/providers", nil) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + body, err := io.ReadAll(w.Body) + require.NoError(ts.T(), err) + + var result struct { + Items []interface{} `json:"items"` + NextToken string `json:"next_token"` + } + + require.NoError(ts.T(), json.Unmarshal(body, &result)) + require.Equal(ts.T(), len(result.Items), 0) + require.Equal(ts.T(), result.NextToken, "") +} + +func (ts *SSOTestSuite) TestAdminGetSSOProviderNotExist() { + examples := []struct { + URL string + }{ + { + URL: "http://localhost/admin/sso/providers/not-a-uuid", + }, + { + URL: "http://localhost/admin/sso/providers/677477db-3f51-4038-bc05-c6bb9bdc3c32", + }, + } + + for _, example := range examples { + req := httptest.NewRequest(http.MethodGet, 
example.URL, nil) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusNotFound, w.Code) + } +} + +func configurableSAMLIDPMetadata(entityID, validUntil, cacheDuration string) string { + return fmt.Sprintf(` + + + + + MIIDdDCCAlygAwIBAgIGAYKSjRZiMA0GCSqGSIb3DQEBCwUAMHsxFDASBgNVBAoTC0dvb2dsZSBJ +bmMuMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MQ8wDQYDVQQDEwZHb29nbGUxGDAWBgNVBAsTD0dv +b2dsZSBGb3IgV29yazELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWEwHhcNMjIwODEy +MTQ1NDU1WhcNMjcwODExMTQ1NDU1WjB7MRQwEgYDVQQKEwtHb29nbGUgSW5jLjEWMBQGA1UEBxMN +TW91bnRhaW4gVmlldzEPMA0GA1UEAxMGR29vZ2xlMRgwFgYDVQQLEw9Hb29nbGUgRm9yIFdvcmsx +CzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAlncFzErcnZm7ZWO71NZStnCIAoYNKf6Uw3LPLzcvk0YrA/eBC3PVDHSfahi+apGO +Ytdq7IQUvBdto3rJTvP49fjyO0WLbAbiPC+dILt2Gx9kttxpSp99Bf+8ObL/fTy5Y2oHbJBfBX1V +qfDQIY0fcej3AndFYUOE0gZXyeSbnROB8W1PzHxOc7rq1mlas0rvyja7AK4gwXjIwyIGsFDmHnve +buqWOYMzOT9oD+iQq9BWYVHkXGZn0BXzKtnw9w8I3IxQdndUoCl95pYRIvdl1b0dWdO9cXtSsTkL +kAa8B/mCQcF4W2M3t/yKtrcLcRTALg3/Hc+Xz+3BpY/fSDk1SwIDAQABMA0GCSqGSIb3DQEBCwUA +A4IBAQCER02WLf6bKwTGVD/3VTntetIiETuPs46Dum8blbsg+2BYdAHIQcB9cLuMRosIw0nYj54m +SfiyfoWGcx3CkMup1MtKyWu+SqDHl9Bpf+GFLG0ngKD/zB6xwpv/TCi+g/FBYe2TvzD6B1V0z7Vs +Xf+Gc2TWBKmCuKf/g2AUt7IQLpOaqxuJVoZjp4sEMov6d3FnaoHQEd0lg+XmnYfLNtwe3QRSU0BD +x6lVV4kXi0x0n198/gkjnA85rPZoZ6dmqHtkcM0Gabgg6KEE5ubSDlWDsdv27uANceCZAoxd1+in +4/KqqkhynnbJs7Op5ZX8cckiHGGTGHNb35kys/XukuCo + + + + urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress + + + +`, entityID, validUntil, cacheDuration, entityID, entityID) + +} + +func (ts *SSOTestSuite) TestIsStaleSAMLMetadata() { + + // https://en.wikipedia.org/wiki/ISO_8601 + currentTime := time.Now() + currentTimeAsISO8601 := currentTime.UTC().Format("2006-01-02T15:04:05Z07:00") + examples := []struct { + Description string + Metadata []byte + IsStale bool + CacheDurationExceeded bool + }{ + { + Description: "Metadata is valid and within cache duration", + Metadata: []byte(configurableSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-B", dateInFarFuture, oneHour)), + IsStale: false, + CacheDurationExceeded: false, + }, + { + + Description: "Metadata is valid but is a minute past cache duration", + Metadata: []byte(configurableSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-B", currentTimeAsISO8601, oneHour)), + IsStale: true, + CacheDurationExceeded: true, + }, + + { + Description: "Metadata is invalid but within cache duration", + Metadata: []byte(configurableSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-B", dateInPast, oneHour)), + IsStale: true, + CacheDurationExceeded: false, + }, + } + + for _, example := range examples { + metadata, err := parseSAMLMetadata(example.Metadata) + require.NoError(ts.T(), err) + provider := models.SAMLProvider{ + EntityID: metadata.EntityID, + MetadataXML: string(example.Metadata), + UpdatedAt: currentTime, + } + if example.CacheDurationExceeded { + provider.UpdatedAt = currentTime.Add(-time.Minute * 59) + } + + require.Equal(ts.T(), example.IsStale, IsSAMLMetadataStale(metadata, provider)) + } + +} + +func validSAMLIDPMetadata(entityID string) string { + return fmt.Sprintf(` + + + + + MIIDdDCCAlygAwIBAgIGAYKSjRZiMA0GCSqGSIb3DQEBCwUAMHsxFDASBgNVBAoTC0dvb2dsZSBJ +bmMuMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MQ8wDQYDVQQDEwZHb29nbGUxGDAWBgNVBAsTD0dv +b2dsZSBGb3IgV29yazELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWEwHhcNMjIwODEy 
+MTQ1NDU1WhcNMjcwODExMTQ1NDU1WjB7MRQwEgYDVQQKEwtHb29nbGUgSW5jLjEWMBQGA1UEBxMN +TW91bnRhaW4gVmlldzEPMA0GA1UEAxMGR29vZ2xlMRgwFgYDVQQLEw9Hb29nbGUgRm9yIFdvcmsx +CzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAlncFzErcnZm7ZWO71NZStnCIAoYNKf6Uw3LPLzcvk0YrA/eBC3PVDHSfahi+apGO +Ytdq7IQUvBdto3rJTvP49fjyO0WLbAbiPC+dILt2Gx9kttxpSp99Bf+8ObL/fTy5Y2oHbJBfBX1V +qfDQIY0fcej3AndFYUOE0gZXyeSbnROB8W1PzHxOc7rq1mlas0rvyja7AK4gwXjIwyIGsFDmHnve +buqWOYMzOT9oD+iQq9BWYVHkXGZn0BXzKtnw9w8I3IxQdndUoCl95pYRIvdl1b0dWdO9cXtSsTkL +kAa8B/mCQcF4W2M3t/yKtrcLcRTALg3/Hc+Xz+3BpY/fSDk1SwIDAQABMA0GCSqGSIb3DQEBCwUA +A4IBAQCER02WLf6bKwTGVD/3VTntetIiETuPs46Dum8blbsg+2BYdAHIQcB9cLuMRosIw0nYj54m +SfiyfoWGcx3CkMup1MtKyWu+SqDHl9Bpf+GFLG0ngKD/zB6xwpv/TCi+g/FBYe2TvzD6B1V0z7Vs +Xf+Gc2TWBKmCuKf/g2AUt7IQLpOaqxuJVoZjp4sEMov6d3FnaoHQEd0lg+XmnYfLNtwe3QRSU0BD +x6lVV4kXi0x0n198/gkjnA85rPZoZ6dmqHtkcM0Gabgg6KEE5ubSDlWDsdv27uANceCZAoxd1+in +4/KqqkhynnbJs7Op5ZX8cckiHGGTGHNb35kys/XukuCo + + + + urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress + + + +`, entityID, entityID, entityID) +} + +func (ts *SSOTestSuite) TestAdminCreateSSOProvider() { + examples := []struct { + StatusCode int + Request map[string]interface{} + }{ + { + StatusCode: http.StatusBadRequest, + Request: map[string]interface{}{}, + }, + { + StatusCode: http.StatusBadRequest, + Request: map[string]interface{}{ + "type": "saml", + }, + }, + { + StatusCode: http.StatusBadRequest, + Request: map[string]interface{}{ + "type": "oidc", + }, + }, + { + StatusCode: http.StatusCreated, + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-A"), + }, + }, + { + StatusCode: http.StatusCreated, + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-B"), + }, + }, + { + StatusCode: http.StatusCreated, + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-DUPLICATE"), + }, + }, + { + StatusCode: http.StatusCreated, + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-WITH-ATTRIBUTE-MAPPING"), + "attribute_mapping": map[string]interface{}{ + "keys": map[string]interface{}{ + "username": map[string]interface{}{ + "name": "mail", + }, + }, + }, + }, + }, + { + StatusCode: http.StatusUnprocessableEntity, + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-DUPLICATE"), + }, + }, + { + StatusCode: http.StatusCreated, + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-WITH-DOMAIN-A"), + "domains": []string{ + "example.com", + }, + }, + }, + { + StatusCode: http.StatusBadRequest, + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-WITH-DOMAIN-B"), + "domains": []string{ + "example.com", + }, + }, + }, + { + StatusCode: http.StatusBadRequest, + Request: map[string]interface{}{ + "type": "saml", + "metadata_url": "https://accounts.google.com/o/saml2?idpid=EXAMPLE-WITH-METADATA-URL-TOO", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-WITH-METADATA-URL-TOO"), + }, + }, + { + StatusCode: 
http.StatusBadRequest, + Request: map[string]interface{}{ + "type": "saml", + "metadata_url": "http://accounts.google.com/o/saml2?idpid=EXAMPLE-WITH-METADATA-OVER-HTTP", + }, + }, + { + StatusCode: http.StatusBadRequest, + Request: map[string]interface{}{ + "type": "saml", + "metadata_url": "https://accounts.google.com\\o/saml2?idpid=EXAMPLE-WITH-INVALID-METADATA-URL", + }, + }, + // TODO: add example with metadata_url + } + + for i, example := range examples { + body, err := json.Marshal(example.Request) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/admin/sso/providers", bytes.NewBuffer(body)) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + response, err := io.ReadAll(w.Body) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), example.StatusCode, w.Code, "Example %d failed with body %q", i, response) + + if example.StatusCode != http.StatusCreated { + continue + } + + // now check if the provider can be queried (GET) + var provider struct { + ID string `json:"id"` + } + + require.NoError(ts.T(), json.Unmarshal(response, &provider)) + + req = httptest.NewRequest(http.MethodGet, "http://localhost/admin/sso/providers/"+provider.ID, nil) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + response, err = io.ReadAll(w.Body) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), http.StatusOK, w.Code) + + originalProviderID := provider.ID + provider.ID = "" + + require.NoError(ts.T(), json.Unmarshal(response, &provider)) + require.Equal(ts.T(), provider.ID, originalProviderID) + + // now check if the provider can be queried (List) + var providers struct { + Items []struct { + ID string `json:"id"` + } `json:"items"` + } + + req = httptest.NewRequest(http.MethodGet, "http://localhost/admin/sso/providers", nil) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w = httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + response, err = io.ReadAll(w.Body) + require.NoError(ts.T(), err) + + require.NoError(ts.T(), json.Unmarshal(response, &providers)) + + contained := false + for _, listProvider := range providers.Items { + if listProvider.ID == provider.ID { + contained = true + break + } + } + + require.True(ts.T(), contained) + } +} + +func (ts *SSOTestSuite) TestAdminUpdateSSOProvider() { + providers := []struct { + ID string + Request map[string]interface{} + }{ + { + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-A"), + }, + }, + { + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-C"), + "domains": []string{ + "example.com", + }, + }, + }, + } + + for i, example := range providers { + body, err := json.Marshal(example.Request) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/admin/sso/providers", bytes.NewBuffer(body)) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + response, err := io.ReadAll(w.Body) + require.NoError(ts.T(), err) + + var payload struct { + ID string `json:"id"` + } + + require.NoError(ts.T(), json.Unmarshal(response, &payload)) + + providers[i].ID = payload.ID + } + + examples := []struct { + ID string + Status int + Request 
map[string]interface{} + }{ + { + ID: providers[0].ID, + Status: http.StatusBadRequest, // changing entity ID + Request: map[string]interface{}{ + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-B"), + }, + }, + { + ID: providers[0].ID, + Status: http.StatusBadRequest, // domain already exists + Request: map[string]interface{}{ + "domains": []string{ + "example.com", + }, + }, + }, + { + ID: providers[1].ID, + Status: http.StatusOK, + Request: map[string]interface{}{ + "domains": []string{ + "example.com", + "example.org", + }, + }, + }, + { + ID: providers[1].ID, + Status: http.StatusOK, + Request: map[string]interface{}{ + "attribute_mapping": map[string]interface{}{ + "keys": map[string]interface{}{ + "username": map[string]interface{}{ + "name": "mail", + }, + }, + }, + }, + }, + } + + for _, example := range examples { + body, err := json.Marshal(example.Request) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPut, "http://localhost/admin/sso/providers/"+example.ID, bytes.NewBuffer(body)) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), w.Code, example.Status) + } +} + +func (ts *SSOTestSuite) TestAdminDeleteSSOProvider() { + providers := []struct { + ID string + Request map[string]interface{} + }{ + { + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-A"), + }, + }, + } + + for i, example := range providers { + body, err := json.Marshal(example.Request) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/admin/sso/providers", bytes.NewBuffer(body)) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + response, err := io.ReadAll(w.Body) + require.NoError(ts.T(), err) + + var payload struct { + ID string `json:"id"` + } + + require.NoError(ts.T(), json.Unmarshal(response, &payload)) + + providers[i].ID = payload.ID + } + + examples := []struct { + ID string + Status int + }{ + { + ID: providers[0].ID, + Status: http.StatusOK, + }, + } + + for _, example := range examples { + req := httptest.NewRequest(http.MethodDelete, "http://localhost/admin/sso/providers/"+example.ID, nil) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), w.Code, example.Status) + } + + check := []struct { + ID string + }{ + { + ID: providers[0].ID, + }, + } + + for _, example := range check { + req := httptest.NewRequest(http.MethodGet, "http://localhost/admin/sso/providers/"+example.ID, nil) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusNotFound, w.Code) + } +} + +func (ts *SSOTestSuite) TestSingleSignOn() { + providers := []struct { + ID string + Request map[string]interface{} + }{ + { + // creates a SAML provider (EXAMPLE-A) + // does not have a domain mapping + Request: map[string]interface{}{ + "type": "saml", + "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-A"), + }, + }, + { + // creates a SAML provider (EXAMPLE-B) + // does have a domain mapping on example.com + Request: map[string]interface{}{ + "type": "saml", + "domains": []string{ + "example.com", + }, + "metadata_xml": 
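The update cases above exercise the admin PUT endpoint: metadata can only be replaced when the EntityID matches, a domain already mapped to another provider is rejected, and omitting the domains key leaves existing mappings untouched. A sketch of the corresponding client call follows; the host, provider ID, and admin JWT are placeholders.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Placeholders: substitute a real provider ID and admin JWT.
	const providerID = "00000000-0000-0000-0000-000000000000"
	const adminJWT = "<admin JWT>"

	body, _ := json.Marshal(map[string]interface{}{
		// Sending "domains" replaces the existing mappings; omitting the key
		// entirely leaves them untouched.
		"domains": []string{"example.com", "example.org"},
		"attribute_mapping": map[string]interface{}{
			"keys": map[string]interface{}{
				"username": map[string]interface{}{"name": "mail"},
			},
		},
	})

	req, _ := http.NewRequest(http.MethodPut,
		"http://localhost:9999/admin/sso/providers/"+providerID, bytes.NewReader(body))
	req.Header.Set("Authorization", "Bearer "+adminJWT)
	req.Header.Set("Content-Type", "application/json")

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer res.Body.Close()
	fmt.Println(res.Status) // "200 OK" when the update is accepted
}
```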
validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-B"), + }, + }, + } + + for i, example := range providers { + body, err := json.Marshal(example.Request) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/admin/sso/providers", bytes.NewBuffer(body)) + req.Header.Set("Authorization", "Bearer "+ts.AdminJWT) + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + response, err := io.ReadAll(w.Body) + require.NoError(ts.T(), err) + + var payload struct { + ID string `json:"id"` + } + + require.NoError(ts.T(), json.Unmarshal(response, &payload)) + + providers[i].ID = payload.ID + } + + examples := []struct { + Code int + Request map[string]interface{} + URL string + }{ + { + // call /sso with provider_id (EXAMPLE-A) + // should be successful and redirect to the EXAMPLE-A SSO URL + Request: map[string]interface{}{ + "provider_id": providers[0].ID, + }, + Code: http.StatusSeeOther, + URL: "https://accounts.google.com/o/saml2?idpid=EXAMPLE-A", + }, + { + // call /sso with provider_id (EXAMPLE-A) and SSO PKCE + // should be successful and redirect to the EXAMPLE-A SSO URL + Request: map[string]interface{}{ + "provider_id": providers[0].ID, + "code_challenge": "vby3iMQ4XUuycKkEyNsYHXshPql1Dod7Ebey2iXTXm4", + "code_challenge_method": "s256", + }, + Code: http.StatusSeeOther, + URL: "https://accounts.google.com/o/saml2?idpid=EXAMPLE-A", + }, + { + // call /sso with domain=example.com (provider=EXAMPLE-B) + // should be successful and redirect to the EXAMPLE-B SSO URL + Request: map[string]interface{}{ + "domain": "example.com", + }, + Code: http.StatusSeeOther, + URL: "https://accounts.google.com/o/saml2?idpid=EXAMPLE-B", + }, + { + // call /sso with domain=example.com (provider=EXAMPLE-B) + // should be successful and redirect to the EXAMPLE-B SSO URL + Request: map[string]interface{}{ + "domain": "example.com", + "skip_http_redirect": true, + }, + Code: http.StatusOK, + URL: "https://accounts.google.com/o/saml2?idpid=EXAMPLE-B", + }, + { + // call /sso with domain=example.org (no such provider) + // should be unsuccessful with 404 + Request: map[string]interface{}{ + "domain": "example.org", + }, + Code: http.StatusNotFound, + }, + { + // call /sso with a provider_id= (no such provider) + // should be unsuccessful with 404 + Request: map[string]interface{}{ + "provider_id": "14d906bf-9bd5-4734-b7d1-3904e240610e", + }, + Code: http.StatusNotFound, + }, + } + + for _, example := range examples { + body, err := json.Marshal(example.Request) + require.NoError(ts.T(), err) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/sso", bytes.NewBuffer(body)) + // no authorization header intentional, this is a login endpoint + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), w.Code, example.Code) + + locationURLString := "" + + if example.Code == http.StatusSeeOther { + locationURLString = w.Header().Get("Location") + } else if example.Code == http.StatusOK { + var response struct { + URL string `json:"url"` + } + + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response)) + + require.NotEmpty(ts.T(), response.URL) + + locationURLString = response.URL + } else { + continue + } + + locationURL, err := url.ParseRequestURI(locationURLString) + require.NoError(ts.T(), err) + + locationQuery, err := url.ParseQuery(locationURL.RawQuery) + + require.NoError(ts.T(), err) + + samlQueryParams := []string{ + "SAMLRequest", + "RelayState", + "SigAlg", + "Signature", + } + + for 
_, param := range samlQueryParams { + require.True(ts.T(), locationQuery.Has(param)) + } + + for _, param := range samlQueryParams { + locationQuery.Del(param) + } + + locationURL.RawQuery = locationQuery.Encode() + + require.Equal(ts.T(), locationURL.String(), example.URL) + } +} + +func TestSSOCreateParamsValidation(t *testing.T) { + // TODO +} diff --git a/auth_v2.169.0/internal/api/ssoadmin.go b/auth_v2.169.0/internal/api/ssoadmin.go new file mode 100644 index 0000000..20fd8b9 --- /dev/null +++ b/auth_v2.169.0/internal/api/ssoadmin.go @@ -0,0 +1,421 @@ +package api + +import ( + "context" + "io" + "net/http" + "net/url" + "strings" + "unicode/utf8" + + "github.com/crewjam/saml" + "github.com/crewjam/saml/samlsp" + "github.com/go-chi/chi/v5" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +// loadSSOProvider looks for an idp_id parameter in the URL route and loads the SSO provider +// with that ID (or resource ID) and adds it to the context. +func (a *API) loadSSOProvider(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx := r.Context() + db := a.db.WithContext(ctx) + + idpParam := chi.URLParam(r, "idp_id") + + idpID, err := uuid.FromString(idpParam) + if err != nil { + // idpParam is not UUIDv4 + return nil, notFoundError(ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") + } + + // idpParam is a UUIDv4 + provider, err := models.FindSSOProviderByID(db, idpID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, notFoundError(ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") + } else { + return nil, internalServerError("Database error finding SSO Identity Provider").WithInternalError(err) + } + } + + observability.LogEntrySetField(r, "sso_provider_id", provider.ID.String()) + + return withSSOProvider(r.Context(), provider), nil +} + +// adminSSOProvidersList lists all SAML SSO Identity Providers in the system. Does +// not deal with pagination at this time. 
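loadSSOProvider above turns the idp_id route parameter into a UUID, loads the provider, and answers 404 for both malformed IDs and missing rows. Below is a minimal sketch of that pattern using the same chi and gofrs/uuid libraries this file imports, with the database lookup stubbed out.

```go
package main

import (
	"fmt"
	"log"
	"net/http"

	"github.com/go-chi/chi/v5"
	"github.com/gofrs/uuid"
)

func main() {
	r := chi.NewRouter()
	r.Get("/admin/sso/providers/{idp_id}", func(w http.ResponseWriter, req *http.Request) {
		idpID, err := uuid.FromString(chi.URLParam(req, "idp_id"))
		if err != nil {
			// Malformed IDs get the same 404 as missing providers.
			http.Error(w, "SSO Identity Provider not found", http.StatusNotFound)
			return
		}
		// A real handler would load the provider from the database here and
		// stash it on the request context for the next handler in the chain.
		fmt.Fprintf(w, "would load provider %s\n", idpID)
	})

	log.Fatal(http.ListenAndServe(":8080", r))
}
```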
+func (a *API) adminSSOProvidersList(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + providers, err := models.FindAllSAMLProviders(db) + if err != nil { + return err + } + + for i := range providers { + // remove metadata XML so that the returned JSON is not ginormous + providers[i].SAMLProvider.MetadataXML = "" + } + + return sendJSON(w, http.StatusOK, map[string]interface{}{ + "items": providers, + }) +} + +type CreateSSOProviderParams struct { + Type string `json:"type"` + + MetadataURL string `json:"metadata_url"` + MetadataXML string `json:"metadata_xml"` + Domains []string `json:"domains"` + AttributeMapping models.SAMLAttributeMapping `json:"attribute_mapping"` + NameIDFormat string `json:"name_id_format"` +} + +func (p *CreateSSOProviderParams) validate(forUpdate bool) error { + if !forUpdate && p.Type != "saml" { + return badRequestError(ErrorCodeValidationFailed, "Only 'saml' supported for SSO provider type") + } else if p.MetadataURL != "" && p.MetadataXML != "" { + return badRequestError(ErrorCodeValidationFailed, "Only one of metadata_xml or metadata_url needs to be set") + } else if !forUpdate && p.MetadataURL == "" && p.MetadataXML == "" { + return badRequestError(ErrorCodeValidationFailed, "Either metadata_xml or metadata_url must be set") + } else if p.MetadataURL != "" { + metadataURL, err := url.ParseRequestURI(p.MetadataURL) + if err != nil { + return badRequestError(ErrorCodeValidationFailed, "metadata_url is not a valid URL") + } + + if metadataURL.Scheme != "https" { + return badRequestError(ErrorCodeValidationFailed, "metadata_url is not a HTTPS URL") + } + } + + switch p.NameIDFormat { + case "", + string(saml.PersistentNameIDFormat), + string(saml.EmailAddressNameIDFormat), + string(saml.TransientNameIDFormat), + string(saml.UnspecifiedNameIDFormat): + // it's valid + + default: + return badRequestError(ErrorCodeValidationFailed, "name_id_format must be unspecified or one of %v", strings.Join([]string{ + string(saml.PersistentNameIDFormat), + string(saml.EmailAddressNameIDFormat), + string(saml.TransientNameIDFormat), + string(saml.UnspecifiedNameIDFormat), + }, ", ")) + } + + return nil +} + +func (p *CreateSSOProviderParams) metadata(ctx context.Context) ([]byte, *saml.EntityDescriptor, error) { + var rawMetadata []byte + var err error + + if p.MetadataXML != "" { + rawMetadata = []byte(p.MetadataXML) + } else if p.MetadataURL != "" { + rawMetadata, err = fetchSAMLMetadata(ctx, p.MetadataURL) + if err != nil { + return nil, nil, err + } + } else { + // impossible situation if you called validate() prior + return nil, nil, nil + } + + metadata, err := parseSAMLMetadata(rawMetadata) + if err != nil { + return nil, nil, err + } + + return rawMetadata, metadata, nil +} + +func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { + if !utf8.Valid(rawMetadata) { + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata XML contains invalid UTF-8 characters, which are not supported at this time") + } + + metadata, err := samlsp.ParseMetadata(rawMetadata) + if err != nil { + return nil, err + } + + if metadata.EntityID == "" { + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata does not contain an EntityID") + } + + if len(metadata.IDPSSODescriptors) < 1 { + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata does not contain any IDPSSODescriptor") + } + + if len(metadata.IDPSSODescriptors) > 1 { + return nil, 
badRequestError(ErrorCodeValidationFailed, "SAML Metadata contains multiple IDPSSODescriptors") + } + + return metadata, nil +} + +func fetchSAMLMetadata(ctx context.Context, url string) ([]byte, error) { + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, internalServerError("Unable to create a request to metadata_url").WithInternalError(err) + } + + req = req.WithContext(ctx) + + req.Header.Set("Accept", "application/xml;charset=UTF-8") + req.Header.Set("Accept-Charset", "UTF-8") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + + defer utilities.SafeClose(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, badRequestError(ErrorCodeSAMLMetadataFetchFailed, "HTTP %v error fetching SAML Metadata from URL '%s'", resp.StatusCode, url) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + return data, nil +} + +// adminSSOProvidersCreate creates a new SAML Identity Provider in the system. +func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + params := &CreateSSOProviderParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if err := params.validate(false /* <- forUpdate */); err != nil { + return err + } + + rawMetadata, metadata, err := params.metadata(ctx) + if err != nil { + return err + } + + existingProvider, err := models.FindSAMLProviderByEntityID(db, metadata.EntityID) + if err != nil && !models.IsNotFoundError(err) { + return err + } + if existingProvider != nil { + return unprocessableEntityError(ErrorCodeSAMLIdPAlreadyExists, "SAML Identity Provider with this EntityID (%s) already exists", metadata.EntityID) + } + + provider := &models.SSOProvider{ + // TODO handle Name, Description, Attribute Mapping + SAMLProvider: models.SAMLProvider{ + EntityID: metadata.EntityID, + MetadataXML: string(rawMetadata), + }, + } + + if params.MetadataURL != "" { + provider.SAMLProvider.MetadataURL = ¶ms.MetadataURL + } + + if params.NameIDFormat != "" { + provider.SAMLProvider.NameIDFormat = ¶ms.NameIDFormat + } + + provider.SAMLProvider.AttributeMapping = params.AttributeMapping + + for _, domain := range params.Domains { + existingProvider, err := models.FindSSOProviderByDomain(db, domain) + if err != nil && !models.IsNotFoundError(err) { + return err + } + if existingProvider != nil { + return badRequestError(ErrorCodeSSODomainAlreadyExists, "SSO Domain '%s' is already assigned to an SSO identity provider (%s)", domain, existingProvider.ID.String()) + } + + provider.SSODomains = append(provider.SSODomains, models.SSODomain{ + Domain: domain, + }) + } + + if err := db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Eager().Create(provider); terr != nil { + return terr + } + + return tx.Eager().Load(provider) + }); err != nil { + return err + } + + return sendJSON(w, http.StatusCreated, provider) +} + +// adminSSOProvidersGet returns an existing SAML Identity Provider in the system. +func (a *API) adminSSOProvidersGet(w http.ResponseWriter, r *http.Request) error { + provider := getSSOProvider(r.Context()) + + return sendJSON(w, http.StatusOK, provider) +} + +// adminSSOProvidersUpdate updates a provider with the provided diff values. 
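validate() only accepts a metadata_url that parses with url.ParseRequestURI and uses the https scheme; fetchSAMLMetadata then retrieves it with XML Accept headers and requires a 200 response. The snippet below isolates just the URL check as a hedged, standalone sketch.

```go
package main

import (
	"errors"
	"fmt"
	"net/url"
)

// validateMetadataURL mirrors the metadata_url rules in validate() above:
// the value must parse as a request URI and must be HTTPS.
func validateMetadataURL(raw string) error {
	u, err := url.ParseRequestURI(raw)
	if err != nil {
		return errors.New("metadata_url is not a valid URL")
	}
	if u.Scheme != "https" {
		return errors.New("metadata_url is not a HTTPS URL")
	}
	return nil
}

func main() {
	fmt.Println(validateMetadataURL("https://idp.example.com/metadata")) // <nil>
	fmt.Println(validateMetadataURL("http://idp.example.com/metadata"))  // not a HTTPS URL
}
```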
+func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + params := &CreateSSOProviderParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if err := params.validate(true /* <- forUpdate */); err != nil { + return err + } + + modified := false + updateSAMLProvider := false + + provider := getSSOProvider(ctx) + + if params.MetadataXML != "" || params.MetadataURL != "" { + // metadata is being updated + rawMetadata, metadata, err := params.metadata(ctx) + if err != nil { + return err + } + + if provider.SAMLProvider.EntityID != metadata.EntityID { + return badRequestError(ErrorCodeSAMLEntityIDMismatch, "SAML Metadata can be updated only if the EntityID matches for the provider; expected '%s' but got '%s'", provider.SAMLProvider.EntityID, metadata.EntityID) + } + + if params.MetadataURL != "" { + provider.SAMLProvider.MetadataURL = ¶ms.MetadataURL + } + + provider.SAMLProvider.MetadataXML = string(rawMetadata) + updateSAMLProvider = true + modified = true + } + + // domains are being "updated" only when params.Domains is not nil, if + // it was nil (but not `[]`) then the caller is expecting not to modify + // the domains + updateDomains := params.Domains != nil + + var createDomains, deleteDomains []models.SSODomain + keepDomains := make(map[string]bool) + + for _, domain := range params.Domains { + existingProvider, err := models.FindSSOProviderByDomain(db, domain) + if err != nil && !models.IsNotFoundError(err) { + return err + } + if existingProvider != nil { + if existingProvider.ID == provider.ID { + keepDomains[domain] = true + } else { + return badRequestError(ErrorCodeSSODomainAlreadyExists, "SSO domain '%s' already assigned to another provider (%s)", domain, existingProvider.ID.String()) + } + } else { + modified = true + createDomains = append(createDomains, models.SSODomain{ + Domain: domain, + SSOProviderID: provider.ID, + }) + } + } + + if updateDomains { + for i, domain := range provider.SSODomains { + if !keepDomains[domain.Domain] { + modified = true + deleteDomains = append(deleteDomains, provider.SSODomains[i]) + } + } + } + + updateAttributeMapping := false + if params.AttributeMapping.Keys != nil { + updateAttributeMapping = !provider.SAMLProvider.AttributeMapping.Equal(¶ms.AttributeMapping) + if updateAttributeMapping { + modified = true + provider.SAMLProvider.AttributeMapping = params.AttributeMapping + } + } + + nameIDFormat := "" + if provider.SAMLProvider.NameIDFormat != nil { + nameIDFormat = *provider.SAMLProvider.NameIDFormat + } + + if params.NameIDFormat != nameIDFormat { + modified = true + + if params.NameIDFormat == "" { + provider.SAMLProvider.NameIDFormat = nil + } else { + provider.SAMLProvider.NameIDFormat = ¶ms.NameIDFormat + } + } + + if modified { + if err := db.Transaction(func(tx *storage.Connection) error { + if terr := tx.Eager().Update(provider); terr != nil { + return terr + } + + if updateDomains { + if terr := tx.Destroy(deleteDomains); terr != nil { + return terr + } + + if terr := tx.Eager().Create(createDomains); terr != nil { + return terr + } + } + + if updateAttributeMapping || updateSAMLProvider { + if terr := tx.Eager().Update(&provider.SAMLProvider); terr != nil { + return terr + } + } + + return tx.Eager().Load(provider) + }); err != nil { + return unprocessableEntityError(ErrorCodeConflict, "Updating SSO provider failed, likely due to a conflict. 
Try again?").WithInternalError(err) + } + } + + return sendJSON(w, http.StatusOK, provider) +} + +// adminSSOProvidersDelete deletes a SAML identity provider. +func (a *API) adminSSOProvidersDelete(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + provider := getSSOProvider(ctx) + + if err := db.Transaction(func(tx *storage.Connection) error { + return tx.Eager().Destroy(provider) + }); err != nil { + return err + } + + return sendJSON(w, http.StatusOK, provider) +} diff --git a/auth_v2.169.0/internal/api/token.go b/auth_v2.169.0/internal/api/token.go new file mode 100644 index 0000000..cc945f2 --- /dev/null +++ b/auth_v2.169.0/internal/api/token.go @@ -0,0 +1,506 @@ +package api + +import ( + "context" + "net/http" + "net/url" + "strconv" + "time" + + "fmt" + + "github.com/gofrs/uuid" + "github.com/golang-jwt/jwt/v5" + "github.com/xeipuuv/gojsonschema" + + "github.com/supabase/auth/internal/hooks" + "github.com/supabase/auth/internal/metering" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" +) + +// AccessTokenClaims is a struct thats used for JWT claims +type AccessTokenClaims struct { + jwt.RegisteredClaims + Email string `json:"email"` + Phone string `json:"phone"` + AppMetaData map[string]interface{} `json:"app_metadata"` + UserMetaData map[string]interface{} `json:"user_metadata"` + Role string `json:"role"` + AuthenticatorAssuranceLevel string `json:"aal,omitempty"` + AuthenticationMethodReference []models.AMREntry `json:"amr,omitempty"` + SessionId string `json:"session_id,omitempty"` + IsAnonymous bool `json:"is_anonymous"` +} + +// AccessTokenResponse represents an OAuth2 success response +type AccessTokenResponse struct { + Token string `json:"access_token"` + TokenType string `json:"token_type"` // Bearer + ExpiresIn int `json:"expires_in"` + ExpiresAt int64 `json:"expires_at"` + RefreshToken string `json:"refresh_token"` + User *models.User `json:"user"` + ProviderAccessToken string `json:"provider_token,omitempty"` + ProviderRefreshToken string `json:"provider_refresh_token,omitempty"` + WeakPassword *WeakPasswordError `json:"weak_password,omitempty"` +} + +// AsRedirectURL encodes the AccessTokenResponse as a redirect URL that +// includes the access token response data in a URL fragment. 
+func (r *AccessTokenResponse) AsRedirectURL(redirectURL string, extraParams url.Values) string { + extraParams.Set("access_token", r.Token) + extraParams.Set("token_type", r.TokenType) + extraParams.Set("expires_in", strconv.Itoa(r.ExpiresIn)) + extraParams.Set("expires_at", strconv.FormatInt(r.ExpiresAt, 10)) + extraParams.Set("refresh_token", r.RefreshToken) + + return redirectURL + "#" + extraParams.Encode() +} + +// PasswordGrantParams are the parameters the ResourceOwnerPasswordGrant method accepts +type PasswordGrantParams struct { + Email string `json:"email"` + Phone string `json:"phone"` + Password string `json:"password"` +} + +// PKCEGrantParams are the parameters the PKCEGrant method accepts +type PKCEGrantParams struct { + AuthCode string `json:"auth_code"` + CodeVerifier string `json:"code_verifier"` +} + +const useCookieHeader = "x-use-cookie" +const InvalidLoginMessage = "Invalid login credentials" + +// Token is the endpoint for OAuth access token requests +func (a *API) Token(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + grantType := r.FormValue("grant_type") + switch grantType { + case "password": + return a.ResourceOwnerPasswordGrant(ctx, w, r) + case "refresh_token": + return a.RefreshTokenGrant(ctx, w, r) + case "id_token": + return a.IdTokenGrant(ctx, w, r) + case "pkce": + return a.PKCE(ctx, w, r) + default: + return badRequestError(ErrorCodeInvalidCredentials, "unsupported_grant_type") + } +} + +// ResourceOwnerPasswordGrant implements the password grant type flow +func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + db := a.db.WithContext(ctx) + + params := &PasswordGrantParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + aud := a.requestAud(ctx, r) + config := a.config + + if params.Email != "" && params.Phone != "" { + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on login.") + } + var user *models.User + var grantParams models.GrantParams + var provider string + var err error + + grantParams.FillGrantParams(r) + + if params.Email != "" { + provider = "email" + if !config.External.Email.Enabled { + return unprocessableEntityError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") + } + user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) + } else if params.Phone != "" { + provider = "phone" + if !config.External.Phone.Enabled { + return unprocessableEntityError(ErrorCodePhoneProviderDisabled, "Phone logins are disabled") + } + params.Phone = formatPhoneNumber(params.Phone) + user, err = models.FindUserByPhoneAndAudience(db, params.Phone, aud) + } else { + return badRequestError(ErrorCodeValidationFailed, "missing email or phone") + } + + if err != nil { + if models.IsNotFoundError(err) { + return badRequestError(ErrorCodeInvalidCredentials, InvalidLoginMessage) + } + return internalServerError("Database error querying schema").WithInternalError(err) + } + + if !user.HasPassword() { + return badRequestError(ErrorCodeInvalidCredentials, InvalidLoginMessage) + } + + if user.IsBanned() { + return badRequestError(ErrorCodeUserBanned, "User is banned") + } + + isValidPassword, shouldReEncrypt, err := user.Authenticate(ctx, db, params.Password, config.Security.DBEncryption.DecryptionKeys, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID) + if err != nil { + return err + } + + var weakPasswordError *WeakPasswordError + if 
isValidPassword { + if err := a.checkPasswordStrength(ctx, params.Password); err != nil { + if wpe, ok := err.(*WeakPasswordError); ok { + weakPasswordError = wpe + } else { + observability.GetLogEntry(r).Entry.WithError(err).Warn("Password strength check on sign-in failed") + } + } + + if shouldReEncrypt { + if err := user.SetPassword(ctx, params.Password, true, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + return err + } + + // directly change this in the database without + // calling user.UpdatePassword() because this + // is not a password change, just encryption + // change in the database + if err := db.UpdateOnly(user, "encrypted_password"); err != nil { + return err + } + } + } + + if config.Hook.PasswordVerificationAttempt.Enabled { + input := hooks.PasswordVerificationAttemptInput{ + UserID: user.ID, + Valid: isValidPassword, + } + output := hooks.PasswordVerificationAttemptOutput{} + if err := a.invokeHook(nil, r, &input, &output); err != nil { + return err + } + + if output.Decision == hooks.HookRejection { + if output.Message == "" { + output.Message = hooks.DefaultPasswordHookRejectionMessage + } + if output.ShouldLogoutUser { + if err := models.Logout(a.db, user.ID); err != nil { + return err + } + } + return badRequestError(ErrorCodeInvalidCredentials, output.Message) + } + } + if !isValidPassword { + return badRequestError(ErrorCodeInvalidCredentials, InvalidLoginMessage) + } + + if params.Email != "" && !user.IsConfirmed() { + return badRequestError(ErrorCodeEmailNotConfirmed, "Email not confirmed") + } else if params.Phone != "" && !user.IsPhoneConfirmed() { + return badRequestError(ErrorCodePhoneNotConfirmed, "Phone not confirmed") + } + + var token *AccessTokenResponse + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + if terr = models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", map[string]interface{}{ + "provider": provider, + }); terr != nil { + return terr + } + token, terr = a.issueRefreshToken(r, tx, user, models.PasswordGrant, grantParams) + if terr != nil { + return terr + } + + return nil + }) + if err != nil { + return err + } + + token.WeakPassword = weakPasswordError + + metering.RecordLogin("password", user.ID) + return sendJSON(w, http.StatusOK, token) +} + +func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + db := a.db.WithContext(ctx) + var grantParams models.GrantParams + + // There is a slight problem with this as it will pick-up the + // User-Agent and IP addresses from the server if used on the server + // side. Currently there's no mechanism to distinguish, but the server + // can be told to at least propagate the User-Agent header. 
+ grantParams.FillGrantParams(r) + + params := &PKCEGrantParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if params.AuthCode == "" || params.CodeVerifier == "" { + return badRequestError(ErrorCodeValidationFailed, "invalid request: both auth code and code verifier should be non-empty") + } + + flowState, err := models.FindFlowStateByAuthCode(db, params.AuthCode) + // Sanity check in case user ID was not set properly + if models.IsNotFoundError(err) || flowState.UserID == nil { + return notFoundError(ErrorCodeFlowStateNotFound, "invalid flow state, no valid flow state found") + } else if err != nil { + return err + } + if flowState.IsExpired(a.config.External.FlowStateExpiryDuration) { + return unprocessableEntityError(ErrorCodeFlowStateExpired, "invalid flow state, flow state has expired") + } + + user, err := models.FindUserByID(db, *flowState.UserID) + if err != nil { + return err + } + if err := flowState.VerifyPKCE(params.CodeVerifier); err != nil { + return badRequestError(ErrorCodeBadCodeVerifier, err.Error()) + } + + var token *AccessTokenResponse + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + authMethod, err := models.ParseAuthenticationMethod(flowState.AuthenticationMethod) + if err != nil { + return err + } + if terr := models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", map[string]interface{}{ + "provider_type": flowState.ProviderType, + }); terr != nil { + return terr + } + token, terr = a.issueRefreshToken(r, tx, user, authMethod, grantParams) + if terr != nil { + // error type is already handled in issueRefreshToken + return terr + } + token.ProviderAccessToken = flowState.ProviderAccessToken + // Because not all providers give out a refresh token + // See corresponding OAuth2 spec: + if flowState.ProviderRefreshToken != "" { + token.ProviderRefreshToken = flowState.ProviderRefreshToken + } + if terr = tx.Destroy(flowState); terr != nil { + return err + } + return nil + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, token) +} + +func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) { + config := a.config + if sessionId == nil { + return "", 0, internalServerError("Session is required to issue access token") + } + sid := sessionId.String() + session, terr := models.FindSessionByID(tx, *sessionId, false) + if terr != nil { + return "", 0, terr + } + aal, amr, terr := session.CalculateAALAndAMR(user) + if terr != nil { + return "", 0, terr + } + + issuedAt := time.Now().UTC() + expiresAt := issuedAt.Add(time.Second * time.Duration(config.JWT.Exp)) + + claims := &hooks.AccessTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: user.ID.String(), + Audience: jwt.ClaimStrings{user.Aud}, + IssuedAt: jwt.NewNumericDate(issuedAt), + ExpiresAt: jwt.NewNumericDate(expiresAt), + Issuer: config.JWT.Issuer, + }, + Email: user.GetEmail(), + Phone: user.GetPhone(), + AppMetaData: user.AppMetaData, + UserMetaData: user.UserMetaData, + Role: user.Role, + SessionId: sid, + AuthenticatorAssuranceLevel: aal.String(), + AuthenticationMethodReference: amr, + IsAnonymous: user.IsAnonymous, + } + + var gotrueClaims jwt.Claims = claims + if config.Hook.CustomAccessToken.Enabled { + input := hooks.CustomAccessTokenInput{ + UserID: user.ID, + Claims: claims, + AuthenticationMethod: authenticationMethod.String(), + } + + output := 
hooks.CustomAccessTokenOutput{} + + err := a.invokeHook(tx, r, &input, &output) + if err != nil { + return "", 0, err + } + gotrueClaims = jwt.MapClaims(output.Claims) + } + + signed, err := signJwt(&config.JWT, gotrueClaims) + if err != nil { + return "", 0, err + } + + return signed, expiresAt.Unix(), nil +} + +func (a *API) issueRefreshToken(r *http.Request, conn *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) { + config := a.config + + now := time.Now() + user.LastSignInAt = &now + + var tokenString string + var expiresAt int64 + var refreshToken *models.RefreshToken + + err := conn.Transaction(func(tx *storage.Connection) error { + var terr error + + refreshToken, terr = models.GrantAuthenticatedUser(tx, user, grantParams) + if terr != nil { + return internalServerError("Database error granting user").WithInternalError(terr) + } + + terr = models.AddClaimToSession(tx, *refreshToken.SessionId, authenticationMethod) + if terr != nil { + return terr + } + + tokenString, expiresAt, terr = a.generateAccessToken(r, tx, user, refreshToken.SessionId, authenticationMethod) + if terr != nil { + // Account for Hook Error + httpErr, ok := terr.(*HTTPError) + if ok { + return httpErr + } + return internalServerError("error generating jwt token").WithInternalError(terr) + } + return nil + }) + if err != nil { + return nil, err + } + + return &AccessTokenResponse{ + Token: tokenString, + TokenType: "bearer", + ExpiresIn: config.JWT.Exp, + ExpiresAt: expiresAt, + RefreshToken: refreshToken.Token, + User: user, + }, nil +} + +func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod, grantParams models.GrantParams) (*AccessTokenResponse, error) { + ctx := r.Context() + config := a.config + var tokenString string + var expiresAt int64 + var refreshToken *models.RefreshToken + currentClaims := getClaims(ctx) + sessionId, err := uuid.FromString(currentClaims.SessionId) + if err != nil { + return nil, internalServerError("Cannot read SessionId claim as UUID").WithInternalError(err) + } + + err = tx.Transaction(func(tx *storage.Connection) error { + if terr := models.AddClaimToSession(tx, sessionId, authenticationMethod); terr != nil { + return terr + } + session, terr := models.FindSessionByID(tx, sessionId, false) + if terr != nil { + return terr + } + currentToken, terr := models.FindTokenBySessionID(tx, &session.ID) + if terr != nil { + return terr + } + if err := tx.Load(user, "Identities"); err != nil { + return err + } + // Swap to ensure current token is the latest one + refreshToken, terr = models.GrantRefreshTokenSwap(r, tx, user, currentToken) + if terr != nil { + return terr + } + aal, _, terr := session.CalculateAALAndAMR(user) + if terr != nil { + return terr + } + + if err := session.UpdateAALAndAssociatedFactor(tx, aal, grantParams.FactorID); err != nil { + return err + } + + tokenString, expiresAt, terr = a.generateAccessToken(r, tx, user, &session.ID, authenticationMethod) + if terr != nil { + httpErr, ok := terr.(*HTTPError) + if ok { + return httpErr + } + return internalServerError("error generating jwt token").WithInternalError(terr) + } + return nil + }) + if err != nil { + return nil, err + } + return &AccessTokenResponse{ + Token: tokenString, + TokenType: "bearer", + ExpiresIn: config.JWT.Exp, + ExpiresAt: expiresAt, + RefreshToken: refreshToken.Token, + User: user, + }, nil +} + +func 
validateTokenClaims(outputClaims map[string]interface{}) error { + schemaLoader := gojsonschema.NewStringLoader(hooks.MinimumViableTokenSchema) + + documentLoader := gojsonschema.NewGoLoader(outputClaims) + + result, err := gojsonschema.Validate(schemaLoader, documentLoader) + if err != nil { + return err + } + + if !result.Valid() { + var errorMessages string + + for _, desc := range result.Errors() { + errorMessages += fmt.Sprintf("- %s\n", desc) + fmt.Printf("- %s\n", desc) + } + return fmt.Errorf("output claims do not conform to the expected schema: \n%s", errorMessages) + + } + + return nil +} diff --git a/auth_v2.169.0/internal/api/token_oidc.go b/auth_v2.169.0/internal/api/token_oidc.go new file mode 100644 index 0000000..b576742 --- /dev/null +++ b/auth_v2.169.0/internal/api/token_oidc.go @@ -0,0 +1,253 @@ +package api + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" +) + +// IdTokenGrantParams are the parameters the IdTokenGrant method accepts +type IdTokenGrantParams struct { + IdToken string `json:"id_token"` + AccessToken string `json:"access_token"` + Nonce string `json:"nonce"` + Provider string `json:"provider"` + ClientID string `json:"client_id"` + Issuer string `json:"issuer"` +} + +func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.GlobalConfiguration, r *http.Request) (*oidc.Provider, bool, string, []string, error) { + log := observability.GetLogEntry(r).Entry + + var cfg *conf.OAuthProviderConfiguration + var issuer string + var providerType string + var acceptableClientIDs []string + + switch true { + case p.Provider == "apple" || p.Issuer == provider.IssuerApple: + cfg = &config.External.Apple + providerType = "apple" + issuer = provider.IssuerApple + acceptableClientIDs = append(acceptableClientIDs, config.External.Apple.ClientID...) + + if config.External.IosBundleId != "" { + acceptableClientIDs = append(acceptableClientIDs, config.External.IosBundleId) + } + + case p.Provider == "google" || p.Issuer == provider.IssuerGoogle: + cfg = &config.External.Google + providerType = "google" + issuer = provider.IssuerGoogle + acceptableClientIDs = append(acceptableClientIDs, config.External.Google.ClientID...) + + case p.Provider == "azure" || provider.IsAzureIssuer(p.Issuer): + issuer = p.Issuer + if issuer == "" || !provider.IsAzureIssuer(issuer) { + detectedIssuer, err := provider.DetectAzureIDTokenIssuer(ctx, p.IdToken) + if err != nil { + return nil, false, "", nil, badRequestError(ErrorCodeValidationFailed, "Unable to detect issuer in ID token for Azure provider").WithInternalError(err) + } + issuer = detectedIssuer + } + cfg = &config.External.Azure + providerType = "azure" + acceptableClientIDs = append(acceptableClientIDs, config.External.Azure.ClientID...) + + case p.Provider == "facebook" || p.Issuer == provider.IssuerFacebook: + cfg = &config.External.Facebook + providerType = "facebook" + issuer = provider.IssuerFacebook + acceptableClientIDs = append(acceptableClientIDs, config.External.Facebook.ClientID...) 
+ + case p.Provider == "keycloak" || (config.External.Keycloak.Enabled && config.External.Keycloak.URL != "" && p.Issuer == config.External.Keycloak.URL): + cfg = &config.External.Keycloak + providerType = "keycloak" + issuer = config.External.Keycloak.URL + acceptableClientIDs = append(acceptableClientIDs, config.External.Keycloak.ClientID...) + + case p.Provider == "kakao" || p.Issuer == provider.IssuerKakao: + cfg = &config.External.Kakao + providerType = "kakao" + issuer = provider.IssuerKakao + acceptableClientIDs = append(acceptableClientIDs, config.External.Kakao.ClientID...) + + case p.Provider == "vercel_marketplace" || p.Issuer == provider.IssuerVercelMarketplace: + cfg = &config.External.VercelMarketplace + providerType = "vercel_marketplace" + issuer = provider.IssuerVercelMarketplace + acceptableClientIDs = append(acceptableClientIDs, config.External.VercelMarketplace.ClientID...) + + default: + log.WithField("issuer", p.Issuer).WithField("client_id", p.ClientID).Warn("Use of POST /token with arbitrary issuer and client_id is deprecated for security reasons. Please switch to using the API with provider only!") + + allowed := false + for _, allowedIssuer := range config.External.AllowedIdTokenIssuers { + if p.Issuer == allowedIssuer { + allowed = true + providerType = allowedIssuer + acceptableClientIDs = []string{p.ClientID} + issuer = allowedIssuer + break + } + } + + if !allowed { + return nil, false, "", nil, badRequestError(ErrorCodeValidationFailed, fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider)) + } + + cfg = &conf.OAuthProviderConfiguration{ + Enabled: true, + SkipNonceCheck: false, + } + } + + if !cfg.Enabled { + return nil, false, "", nil, badRequestError(ErrorCodeProviderDisabled, fmt.Sprintf("Provider (issuer %q) is not enabled", issuer)) + } + + oidcProvider, err := oidc.NewProvider(ctx, issuer) + if err != nil { + return nil, false, "", nil, err + } + + return oidcProvider, cfg.SkipNonceCheck, providerType, acceptableClientIDs, nil +} + +// IdTokenGrant implements the id_token grant type flow +func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + log := observability.GetLogEntry(r).Entry + + db := a.db.WithContext(ctx) + config := a.config + + params := &IdTokenGrantParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if params.IdToken == "" { + return oauthError("invalid request", "id_token required") + } + + if params.Provider == "" && (params.ClientID == "" || params.Issuer == "") { + return oauthError("invalid request", "provider or client_id and issuer required") + } + + oidcProvider, skipNonceCheck, providerType, acceptableClientIDs, err := params.getProvider(ctx, config, r) + if err != nil { + return err + } + + idToken, userData, err := provider.ParseIDToken(ctx, oidcProvider, nil, params.IdToken, provider.ParseIDTokenOptions{ + SkipAccessTokenCheck: params.AccessToken == "", + AccessToken: params.AccessToken, + }) + if err != nil { + return oauthError("invalid request", "Bad ID token").WithInternalError(err) + } + + userData.Metadata.EmailVerified = false + for _, email := range userData.Emails { + if email.Primary { + userData.Metadata.Email = email.Email + userData.Metadata.EmailVerified = email.Verified + break + } else { + userData.Metadata.Email = email.Email + userData.Metadata.EmailVerified = email.Verified + } + } + + if idToken.Subject == "" { + return oauthError("invalid request", "Missing sub claim in id_token") + } + + correctAudience := false + for 
_, clientID := range acceptableClientIDs { + if clientID == "" { + continue + } + + for _, aud := range idToken.Audience { + if aud == clientID { + correctAudience = true + break + } + } + + if correctAudience { + break + } + } + + if !correctAudience { + return oauthError("invalid request", fmt.Sprintf("Unacceptable audience in id_token: %v", idToken.Audience)) + } + + if !skipNonceCheck { + tokenHasNonce := idToken.Nonce != "" + paramsHasNonce := params.Nonce != "" + + if tokenHasNonce != paramsHasNonce { + return oauthError("invalid request", "Passed nonce and nonce in id_token should either both exist or not.") + } else if tokenHasNonce && paramsHasNonce { + // verify nonce to mitigate replay attacks + hash := fmt.Sprintf("%x", sha256.Sum256([]byte(params.Nonce))) + if hash != idToken.Nonce { + return oauthError("invalid nonce", "Nonces mismatch") + } + } + } + + if params.AccessToken == "" { + if idToken.AccessTokenHash != "" { + log.Warn("ID token has a at_hash claim, but no access_token parameter was provided. In future versions, access_token will be mandatory as it's security best practice.") + } + } else { + if idToken.AccessTokenHash == "" { + log.Info("ID token does not have a at_hash claim, access_token parameter is unused.") + } + } + + var token *AccessTokenResponse + var grantParams models.GrantParams + + grantParams.FillGrantParams(r) + + if err := db.Transaction(func(tx *storage.Connection) error { + var user *models.User + var terr error + + user, terr = a.createAccountFromExternalIdentity(tx, r, userData, providerType) + if terr != nil { + return terr + } + + token, terr = a.issueRefreshToken(r, tx, user, models.OAuth, grantParams) + if terr != nil { + return terr + } + + return nil + }); err != nil { + switch err.(type) { + case *storage.CommitWithError: + return err + case *HTTPError: + return err + default: + return oauthError("server_error", "Internal Server Error").WithInternalError(err) + } + } + + return sendJSON(w, http.StatusOK, token) +} diff --git a/auth_v2.169.0/internal/api/token_oidc_test.go b/auth_v2.169.0/internal/api/token_oidc_test.go new file mode 100644 index 0000000..1eab99e --- /dev/null +++ b/auth_v2.169.0/internal/api/token_oidc_test.go @@ -0,0 +1,69 @@ +package api + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" +) + +type TokenOIDCTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestTokenOIDC(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &TokenOIDCTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func SetupTestOIDCProvider(ts *TokenOIDCTestSuite) *httptest.Server { + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid-configuration": + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"issuer":"` + server.URL + `","authorization_endpoint":"` + server.URL + `/authorize","token_endpoint":"` + server.URL + `/token","jwks_uri":"` + server.URL + `/jwks"}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + return server +} + +func (ts *TokenOIDCTestSuite) TestGetProvider() { + server := SetupTestOIDCProvider(ts) + defer server.Close() + + params := &IdTokenGrantParams{ + IdToken: "test-id-token", + AccessToken: "test-access-token", + Nonce: 
"test-nonce", + Provider: server.URL, + ClientID: "test-client-id", + Issuer: server.URL, + } + + ts.Config.External.AllowedIdTokenIssuers = []string{server.URL} + + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + oidcProvider, skipNonceCheck, providerType, acceptableClientIds, err := params.getProvider(context.Background(), ts.Config, req) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), oidcProvider) + require.False(ts.T(), skipNonceCheck) + require.Equal(ts.T(), params.Provider, providerType) + require.NotEmpty(ts.T(), acceptableClientIds) +} diff --git a/auth_v2.169.0/internal/api/token_refresh.go b/auth_v2.169.0/internal/api/token_refresh.go new file mode 100644 index 0000000..7eae233 --- /dev/null +++ b/auth_v2.169.0/internal/api/token_refresh.go @@ -0,0 +1,274 @@ +package api + +import ( + "context" + mathRand "math/rand" + "net/http" + "time" + + "github.com/supabase/auth/internal/metering" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +const retryLoopDuration = 5.0 + +// RefreshTokenGrantParams are the parameters the RefreshTokenGrant method accepts +type RefreshTokenGrantParams struct { + RefreshToken string `json:"refresh_token"` +} + +// RefreshTokenGrant implements the refresh_token grant type flow +func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + db := a.db.WithContext(ctx) + config := a.config + + params := &RefreshTokenGrantParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + if params.RefreshToken == "" { + return oauthError("invalid_request", "refresh_token required") + } + + // A 5 second retry loop is used to make sure that refresh token + // requests do not waste database connections waiting for each other. + // Instead of waiting at the database level, they're waiting at the API + // level instead and retry to refresh the locked row every 10-30 + // milliseconds. + retryStart := a.Now() + retry := true + + for retry && time.Since(retryStart).Seconds() < retryLoopDuration { + retry = false + + user, token, session, err := models.FindUserWithRefreshToken(db, params.RefreshToken, false) + if err != nil { + if models.IsNotFoundError(err) { + return badRequestError(ErrorCodeRefreshTokenNotFound, "Invalid Refresh Token: Refresh Token Not Found") + } + return internalServerError(err.Error()) + } + + if user.IsBanned() { + return badRequestError(ErrorCodeUserBanned, "Invalid Refresh Token: User Banned") + } + + if session == nil { + // a refresh token won't have a session if it's created prior to the sessions table introduced + if err := db.Destroy(token); err != nil { + return internalServerError("Error deleting refresh token with missing session").WithInternalError(err) + } + return badRequestError(ErrorCodeSessionNotFound, "Invalid Refresh Token: No Valid Session Found") + } + + result := session.CheckValidity(retryStart, &token.UpdatedAt, config.Sessions.Timebox, config.Sessions.InactivityTimeout) + + switch result { + case models.SessionValid: + // do nothing + + case models.SessionTimedOut: + return badRequestError(ErrorCodeSessionExpired, "Invalid Refresh Token: Session Expired (Inactivity)") + + default: + return badRequestError(ErrorCodeSessionExpired, "Invalid Refresh Token: Session Expired") + } + + // Basic checks above passed, now we need to serialize access + // to the session in a transaction so that there's no + // concurrent modification. 
In the event that the refresh + // token's row or session is locked, the transaction is closed + // and the whole process will be retried a bit later so that + // the connection pool does not get exhausted. + + var tokenString string + var expiresAt int64 + var newTokenResponse *AccessTokenResponse + + err = db.Transaction(func(tx *storage.Connection) error { + user, token, session, terr := models.FindUserWithRefreshToken(tx, params.RefreshToken, true /* forUpdate */) + if terr != nil { + if models.IsNotFoundError(terr) { + // because forUpdate was set, and the + // previous check outside the + // transaction found a refresh token + // and session, but now we're getting a + // IsNotFoundError, this means that the + // refresh token row and session are + // probably locked so we need to retry + // in a few milliseconds. + retry = true + return terr + } + return internalServerError(terr.Error()) + } + + if a.config.Sessions.SinglePerUser { + sessions, terr := models.FindAllSessionsForUser(tx, user.ID, true /* forUpdate */) + if models.IsNotFoundError(terr) { + // because forUpdate was set, and the + // previous check outside the + // transaction found a user and + // session, but now we're getting a + // IsNotFoundError, this means that the + // user is locked and we need to retry + // in a few milliseconds + retry = true + return terr + } else if terr != nil { + return internalServerError(terr.Error()) + } + + sessionTag := session.DetermineTag(config.Sessions.Tags) + + // go through all sessions of the user and + // check if the current session is the user's + // most recently refreshed valid session + for _, s := range sessions { + if s.ID == session.ID { + // current session, skip it + continue + } + + if s.CheckValidity(retryStart, nil, config.Sessions.Timebox, config.Sessions.InactivityTimeout) != models.SessionValid { + // session is not valid so it + // can't be regarded as active + // on the user + continue + } + + if s.DetermineTag(config.Sessions.Tags) != sessionTag { + // if tags are specified, + // ignore sessions with a + // mismatching tag + continue + } + + // since token is not the refresh token + // of s, we can't use it's UpdatedAt + // time to compare! + if s.LastRefreshedAt(nil).After(session.LastRefreshedAt(&token.UpdatedAt)) { + // session is not the most + // recently active one + return badRequestError(ErrorCodeSessionExpired, "Invalid Refresh Token: Session Expired (Revoked by Newer Login)") + } + } + + // this session is the user's active session + } + + // refresh token row and session are locked at this + // point, cannot be concurrently refreshed + + var issuedToken *models.RefreshToken + + if token.Revoked { + activeRefreshToken, terr := session.FindCurrentlyActiveRefreshToken(tx) + if terr != nil && !models.IsNotFoundError(terr) { + return internalServerError(terr.Error()) + } + + if activeRefreshToken != nil && activeRefreshToken.Parent.String() == token.Token { + // Token was revoked, but it's the + // parent of the currently active one. + // This indicates that the client was + // not able to store the result when it + // refreshed token. This case is + // allowed, provided we return back the + // active refresh token instead of + // creating a new one. + issuedToken = activeRefreshToken + } else { + // For a revoked refresh token to be reused, it + // has to fall within the reuse interval. 
+ reuseUntil := token.UpdatedAt.Add( + time.Second * time.Duration(config.Security.RefreshTokenReuseInterval)) + + if a.Now().After(reuseUntil) { + // not OK to reuse this token + if config.Security.RefreshTokenRotationEnabled { + // Revoke all tokens in token family + if err := models.RevokeTokenFamily(tx, token); err != nil { + return internalServerError(err.Error()) + } + } + + return storage.NewCommitWithError(badRequestError(ErrorCodeRefreshTokenAlreadyUsed, "Invalid Refresh Token: Already Used").WithInternalMessage("Possible abuse attempt: %v", token.ID)) + } + } + } + + if terr = models.NewAuditLogEntry(r, tx, user, models.TokenRefreshedAction, "", nil); terr != nil { + return terr + } + + if issuedToken == nil { + newToken, terr := models.GrantRefreshTokenSwap(r, tx, user, token) + if terr != nil { + return terr + } + + issuedToken = newToken + } + + tokenString, expiresAt, terr = a.generateAccessToken(r, tx, user, issuedToken.SessionId, models.TokenRefresh) + if terr != nil { + httpErr, ok := terr.(*HTTPError) + if ok { + return httpErr + } + return internalServerError("error generating jwt token").WithInternalError(terr) + } + + refreshedAt := a.Now() + session.RefreshedAt = &refreshedAt + + userAgent := r.Header.Get("User-Agent") + if userAgent != "" { + session.UserAgent = &userAgent + } else { + session.UserAgent = nil + } + + ipAddress := utilities.GetIPAddress(r) + if ipAddress != "" { + session.IP = &ipAddress + } else { + session.IP = nil + } + + if terr := session.UpdateOnlyRefreshInfo(tx); terr != nil { + return internalServerError("failed to update session information").WithInternalError(terr) + } + + newTokenResponse = &AccessTokenResponse{ + Token: tokenString, + TokenType: "bearer", + ExpiresIn: config.JWT.Exp, + ExpiresAt: expiresAt, + RefreshToken: issuedToken.Token, + User: user, + } + + return nil + }) + if err != nil { + if retry && models.IsNotFoundError(err) { + // refresh token and session row were likely locked, so + // we need to wait a moment before retrying the whole + // process anew + time.Sleep(time.Duration(10+mathRand.Intn(20)) * time.Millisecond) // #nosec + continue + } else { + return err + } + } + metering.RecordLogin("token", user.ID) + return sendJSON(w, http.StatusOK, newTokenResponse) + } + + return conflictError("Too many concurrent token refresh requests on the same session or refresh token") +} diff --git a/auth_v2.169.0/internal/api/token_test.go b/auth_v2.169.0/internal/api/token_test.go new file mode 100644 index 0000000..fc89d4f --- /dev/null +++ b/auth_v2.169.0/internal/api/token_test.go @@ -0,0 +1,857 @@ +package api + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type TokenTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration + + RefreshToken *models.RefreshToken + User *models.User +} + +func TestToken(t *testing.T) { + os.Setenv("GOTRUE_RATE_LIMIT_HEADER", "My-Custom-Header") + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &TokenTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *TokenTestSuite) SetupTest() { + ts.RefreshToken = nil + models.TruncateAll(ts.API.db) + + // Create user & refresh 
token + u, err := models.NewUser("", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + t := time.Now() + u.EmailConfirmedAt = &t + u.BannedUntil = nil + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + + ts.User = u + ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{}) + require.NoError(ts.T(), err, "Error creating refresh token") + ts.Config.Hook.CustomAccessToken.Enabled = false + +} + +func (ts *TokenTestSuite) TestSessionTimebox() { + timebox := 10 * time.Second + + ts.API.config.Sessions.Timebox = &timebox + ts.API.overrideTime = func() time.Time { + return time.Now().Add(timebox).Add(time.Second) + } + + defer func() { + ts.API.overrideTime = nil + ts.API.config.Sessions.Timebox = nil + }() + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var firstResult struct { + ErrorCode string `json:"error_code"` + Message string `json:"msg"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + assert.Equal(ts.T(), ErrorCodeSessionExpired, firstResult.ErrorCode) + assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired", firstResult.Message) +} + +func (ts *TokenTestSuite) TestSessionInactivityTimeout() { + inactivityTimeout := 10 * time.Second + + ts.API.config.Sessions.InactivityTimeout = &inactivityTimeout + ts.API.overrideTime = func() time.Time { + return time.Now().Add(inactivityTimeout).Add(time.Second) + } + + defer func() { + ts.API.config.Sessions.InactivityTimeout = nil + ts.API.overrideTime = nil + }() + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var firstResult struct { + ErrorCode string `json:"error_code"` + Message string `json:"msg"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + assert.Equal(ts.T(), ErrorCodeSessionExpired, firstResult.ErrorCode) + assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired (Inactivity)", firstResult.Message) +} + +func (ts *TokenTestSuite) TestFailedToSaveRefreshTokenResultCase() { + var buffer bytes.Buffer + + // first refresh + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var firstResult struct { + RefreshToken string `json:"refresh_token"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + assert.NotEmpty(ts.T(), firstResult.RefreshToken) + + // pretend 
that the browser wasn't able to save the firstResult, + // run again with the first refresh token + buffer = bytes.Buffer{} + + // second refresh with the reused refresh token + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var secondResult struct { + RefreshToken string `json:"refresh_token"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&secondResult)) + assert.NotEmpty(ts.T(), secondResult.RefreshToken) + + // new refresh token is not being issued but the active one from + // the first refresh that failed to save is stored + assert.Equal(ts.T(), firstResult.RefreshToken, secondResult.RefreshToken) +} + +func (ts *TokenTestSuite) TestSingleSessionPerUserNoTags() { + ts.API.config.Sessions.SinglePerUser = true + defer func() { + ts.API.config.Sessions.SinglePerUser = false + }() + + firstRefreshToken := ts.RefreshToken + + // just in case to give some delay between first and second session creation + time.Sleep(10 * time.Millisecond) + + secondRefreshToken, err := models.GrantAuthenticatedUser(ts.API.db, ts.User, models.GrantParams{}) + + require.NoError(ts.T(), err) + + require.NotEqual(ts.T(), *firstRefreshToken.SessionId, *secondRefreshToken.SessionId) + require.Equal(ts.T(), firstRefreshToken.UserID, secondRefreshToken.UserID) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": firstRefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + assert.True(ts.T(), ts.API.config.Sessions.SinglePerUser) + + var firstResult struct { + ErrorCode string `json:"error_code"` + Message string `json:"msg"` + } + + assert.NoError(ts.T(), json.NewDecoder(w.Result().Body).Decode(&firstResult)) + assert.Equal(ts.T(), ErrorCodeSessionExpired, firstResult.ErrorCode) + assert.Equal(ts.T(), "Invalid Refresh Token: Session Expired (Revoked by Newer Login)", firstResult.Message) +} + +func (ts *TokenTestSuite) TestRateLimitTokenRefresh() { + var buffer bytes.Buffer + req := httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("My-Custom-Header", "1.2.3.4") + + // It rate limits after 30 requests + for i := 0; i < 30; i++ { + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + } + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code) + + // It ignores X-Forwarded-For by default + req.Header.Set("X-Forwarded-For", "1.1.1.1") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusTooManyRequests, w.Code) + + // It doesn't rate limit a new value for the limited header + req = httptest.NewRequest(http.MethodPost, "http://localhost/token", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("My-Custom-Header", "5.6.7.8") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) +} + +func (ts *TokenTestSuite) 
TestTokenPasswordGrantSuccess() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "password": "password", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) +} + +func (ts *TokenTestSuite) TestTokenRefreshTokenGrantSuccess() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) +} + +func (ts *TokenTestSuite) TestTokenPasswordGrantFailure() { + u := ts.createBannedUser() + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": u.GetEmail(), + "password": "password", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) +} + +func (ts *TokenTestSuite) TestTokenPKCEGrantFailure() { + authCode := "1234563" + codeVerifier := "4a9505b9-0857-42bb-ab3c-098b4d28ddc2" + invalidAuthCode := authCode + "123" + invalidVerifier := codeVerifier + "123" + codeChallenge := sha256.Sum256([]byte(codeVerifier)) + challenge := base64.RawURLEncoding.EncodeToString(codeChallenge[:]) + flowState := models.NewFlowState("github", challenge, models.SHA256, models.OAuth, nil) + flowState.AuthCode = authCode + require.NoError(ts.T(), ts.API.db.Create(flowState)) + cases := []struct { + desc string + authCode string + codeVerifier string + grantType string + expectedHTTPCode int + }{ + { + desc: "Invalid Authcode", + authCode: invalidAuthCode, + codeVerifier: codeVerifier, + }, + { + desc: "Invalid code verifier", + authCode: authCode, + codeVerifier: invalidVerifier, + }, + { + desc: "Invalid auth code and verifier", + authCode: invalidAuthCode, + codeVerifier: invalidVerifier, + }, + } + for _, v := range cases { + ts.Run(v.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "code_verifier": v.codeVerifier, + "auth_code": v.authCode, + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=pkce", &buffer) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusNotFound, w.Code) + }) + } +} + +func (ts *TokenTestSuite) TestTokenRefreshTokenGrantFailure() { + _ = ts.createBannedUser() + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) +} + +func (ts *TokenTestSuite) 
TestRefreshTokenReuseRevocation() { + originalSecurity := ts.API.config.Security + + ts.API.config.Security.RefreshTokenRotationEnabled = true + ts.API.config.Security.RefreshTokenReuseInterval = 0 + + defer func() { + ts.API.config.Security = originalSecurity + }() + + refreshTokens := []string{ + ts.RefreshToken.Token, + } + + for i := 0; i < 3; i += 1 { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": refreshTokens[len(refreshTokens)-1], + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var response struct { + RefreshToken string `json:"refresh_token"` + } + + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response)) + + refreshTokens = append(refreshTokens, response.RefreshToken) + } + + // ensure that the 4 refresh tokens are setup correctly + for i, refreshToken := range refreshTokens { + _, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false) + require.NoError(ts.T(), err) + + if i == len(refreshTokens)-1 { + require.False(ts.T(), token.Revoked) + } else { + require.True(ts.T(), token.Revoked) + } + } + + // try to reuse the first (earliest) refresh token which should trigger the family revocation logic + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": refreshTokens[0], + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) + + var response struct { + ErrorCode string `json:"error_code"` + Message string `json:"msg"` + } + + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response)) + require.Equal(ts.T(), ErrorCodeRefreshTokenAlreadyUsed, response.ErrorCode) + require.Equal(ts.T(), "Invalid Refresh Token: Already Used", response.Message) + + // ensure that the refresh tokens are marked as revoked in the database + for _, refreshToken := range refreshTokens { + _, token, _, err := models.FindUserWithRefreshToken(ts.API.db, refreshToken, false) + require.NoError(ts.T(), err) + + require.True(ts.T(), token.Revoked) + } + + // finally ensure that none of the refresh tokens can be reused any + // more, starting with the previously valid one + for i := len(refreshTokens) - 1; i >= 0; i -= 1 { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": refreshTokens[i], + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), http.StatusBadRequest, w.Code, "For refresh token %d", i) + + var response struct { + ErrorCode string `json:"error_code"` + Message string `json:"msg"` + } + + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&response)) + require.Equal(ts.T(), ErrorCodeRefreshTokenAlreadyUsed, response.ErrorCode, "For refresh token %d", i) + require.Equal(ts.T(), "Invalid Refresh Token: Already Used", response.Message, "For refresh token %d", i) + } +} + 
+func (ts *TokenTestSuite) createBannedUser() *models.User { + u, err := models.NewUser("", "banned@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + t := time.Now() + u.EmailConfirmedAt = &t + t = t.Add(24 * time.Hour) + u.BannedUntil = &t + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test banned user") + + ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{}) + require.NoError(ts.T(), err, "Error creating refresh token") + + return u +} + +func (ts *TokenTestSuite) TestTokenRefreshWithExpiredSession() { + var err error + + now := time.Now().UTC().Add(-1 * time.Second) + + ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, ts.User, models.GrantParams{ + SessionNotAfter: &now, + }) + require.NoError(ts.T(), err, "Error creating refresh token") + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusBadRequest, w.Code) +} + +func (ts *TokenTestSuite) TestTokenRefreshWithUnexpiredSession() { + var err error + + now := time.Now().UTC().Add(1 * time.Second) + + ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, ts.User, models.GrantParams{ + SessionNotAfter: &now, + }) + require.NoError(ts.T(), err, "Error creating refresh token") + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) +} + +func (ts *TokenTestSuite) TestMagicLinkPKCESignIn() { + var buffer bytes.Buffer + // Send OTP + codeVerifier := "4a9505b9-0857-42bb-ab3c-098b4d28ddc2" + codeChallenge := sha256.Sum256([]byte(codeVerifier)) + challenge := base64.RawURLEncoding.EncodeToString(codeChallenge[:]) + + req := httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(OtpParams{ + Email: "test@example.com", + CreateUser: true, + CodeChallengeMethod: "s256", + CodeChallenge: challenge, + })) + req = httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // Verify OTP + requestUrl := fmt.Sprintf("http://localhost/verify?type=%v&token=%v", "magiclink", u.RecoveryToken) + req = httptest.NewRequest(http.MethodGet, requestUrl, &buffer) + req.Header.Set("Content-Type", "application/json") + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + rURL, _ := w.Result().Location() + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + 
assert.True(ts.T(), u.IsConfirmed()) + + f, err := url.ParseQuery(rURL.RawQuery) + require.NoError(ts.T(), err) + authCode := f.Get("code") + assert.NotEmpty(ts.T(), authCode) + // Extract token and sign in + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "code_verifier": codeVerifier, + "auth_code": authCode, + })) + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=pkce", &buffer) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + verifyResp := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&verifyResp)) + require.NotEmpty(ts.T(), verifyResp.Token) + +} + +func (ts *TokenTestSuite) TestPasswordVerificationHook() { + type verificationHookTestcase struct { + desc string + uri string + hookFunctionSQL string + expectedCode int + } + cases := []verificationHookTestcase{ + { + desc: "Default success", + uri: "pg-functions://postgres/auth/password_verification_hook", + hookFunctionSQL: ` + create or replace function password_verification_hook(input jsonb) + returns jsonb as $$ + begin + return jsonb_build_object('decision', 'continue'); + end; $$ language plpgsql;`, + expectedCode: http.StatusOK, + }, { + desc: "Reject- Enabled", + uri: "pg-functions://postgres/auth/password_verification_hook_reject", + hookFunctionSQL: ` + create or replace function password_verification_hook_reject(input jsonb) + returns jsonb as $$ + begin + return jsonb_build_object('decision', 'reject', 'message', 'You shall not pass!'); + end; $$ language plpgsql;`, + expectedCode: http.StatusBadRequest, + }, + } + for _, c := range cases { + ts.T().Run(c.desc, func(t *testing.T) { + ts.Config.Hook.PasswordVerificationAttempt.Enabled = true + ts.Config.Hook.PasswordVerificationAttempt.URI = c.uri + require.NoError(ts.T(), ts.Config.Hook.PasswordVerificationAttempt.PopulateExtensibilityPoint()) + + err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec() + require.NoError(t, err) + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + "password": "password", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), c.expectedCode, w.Code) + cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.PasswordVerificationAttempt.HookName) + require.NoError(ts.T(), ts.API.db.RawQuery(cleanupHookSQL).Exec()) + // Reset so it doesn't affect other tests + ts.Config.Hook.PasswordVerificationAttempt.Enabled = false + + }) + } + +} + +func (ts *TokenTestSuite) TestCustomAccessToken() { + type customAccessTokenTestcase struct { + desc string + uri string + hookFunctionSQL string + expectedClaims map[string]interface{} + shouldError bool + } + cases := []customAccessTokenTestcase{ + { + desc: "Add a new claim", + uri: "pg-functions://postgres/auth/custom_access_token_add_claim", + hookFunctionSQL: ` create or replace function custom_access_token_add_claim(input jsonb) returns jsonb as $$ declare result jsonb; begin if jsonb_typeof(jsonb_object_field(input, 'claims')) is null then result := jsonb_build_object('error', jsonb_build_object('http_code', 400, 'message', 'Input does not contain claims field')); return result; end if; + input := 
jsonb_set(input, '{claims,newclaim}', '"newvalue"', true); + result := jsonb_build_object('claims', input->'claims'); + return result; +end; $$ language plpgsql;`, + expectedClaims: map[string]interface{}{ + "newclaim": "newvalue", + }, + }, { + desc: "Delete the Role claim", + uri: "pg-functions://postgres/auth/custom_access_token_delete_claim", + hookFunctionSQL: ` +create or replace function custom_access_token_delete_claim(input jsonb) +returns jsonb as $$ +declare + result jsonb; +begin + input := jsonb_set(input, '{claims}', (input->'claims') - 'role'); + result := jsonb_build_object('claims', input->'claims'); + return result; +end; $$ language plpgsql;`, + expectedClaims: map[string]interface{}{}, + shouldError: true, + }, { + desc: "Delete a non-required claim (UserMetadata)", + uri: "pg-functions://postgres/auth/custom_access_token_delete_usermetadata", + hookFunctionSQL: ` +create or replace function custom_access_token_delete_usermetadata(input jsonb) +returns jsonb as $$ +declare + result jsonb; +begin + input := jsonb_set(input, '{claims}', (input->'claims') - 'user_metadata'); + result := jsonb_build_object('claims', input->'claims'); + return result; +end; $$ language plpgsql;`, + // Not used + expectedClaims: map[string]interface{}{ + "user_metadata": nil, + }, + shouldError: false, + }, + } + for _, c := range cases { + ts.T().Run(c.desc, func(t *testing.T) { + ts.Config.Hook.CustomAccessToken.Enabled = true + ts.Config.Hook.CustomAccessToken.URI = c.uri + require.NoError(t, ts.Config.Hook.CustomAccessToken.PopulateExtensibilityPoint()) + + err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec() + require.NoError(t, err) + + var buffer bytes.Buffer + require.NoError(t, json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": ts.RefreshToken.Token, + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + var tokenResponse struct { + AccessToken string `json:"access_token"` + } + require.NoError(t, json.NewDecoder(w.Result().Body).Decode(&tokenResponse)) + if c.shouldError { + require.Equal(t, http.StatusInternalServerError, w.Code) + } else { + parts := strings.Split(tokenResponse.AccessToken, ".") + require.Equal(t, 3, len(parts), "Token should have 3 parts") + + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + require.NoError(t, err) + + var responseClaims map[string]interface{} + require.NoError(t, json.Unmarshal(payload, &responseClaims)) + + for key, expectedValue := range c.expectedClaims { + if expectedValue == nil { + // Since c.shouldError is false here, we only need to check if the claim should be removed + _, exists := responseClaims[key] + assert.False(t, exists, "Claim should be removed") + } else { + assert.Equal(t, expectedValue, responseClaims[key]) + } + } + } + + cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.CustomAccessToken.HookName) + require.NoError(t, ts.API.db.RawQuery(cleanupHookSQL).Exec()) + ts.Config.Hook.CustomAccessToken.Enabled = false + }) + } +} + +func (ts *TokenTestSuite) TestAllowSelectAuthenticationMethods() { + + companyUser, err := models.NewUser("12345678", "test@company.com", "password", ts.Config.JWT.Aud, nil) + t := time.Now() + companyUser.EmailConfirmedAt = &t + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(companyUser), "Error saving 
new test user") + + type allowSelectAuthMethodsTestcase struct { + desc string + uri string + email string + expectedError string + expectedStatus int + } + + // Common hook function SQL definition + hookFunctionSQL := ` +create or replace function auth.custom_access_token(event jsonb) returns jsonb language plpgsql as $$ +declare + email_claim text; + authentication_method text; +begin + email_claim := event->'claims'->>'email'; + authentication_method := event->>'authentication_method'; + + if authentication_method = 'password' and email_claim not like '%@company.com' then + return jsonb_build_object( + 'error', jsonb_build_object( + 'http_code', 403, + 'message', 'only members on company.com can access with password authentication' + ) + ); + end if; + + return event; +end; +$$;` + + cases := []allowSelectAuthMethodsTestcase{ + { + desc: "Error for non-protected domain with password authentication", + uri: "pg-functions://postgres/auth/custom_access_token", + email: "test@example.com", + expectedError: "only members on company.com can access with password authentication", + expectedStatus: http.StatusForbidden, + }, + { + desc: "Allow access for protected domain with password authentication", + uri: "pg-functions://postgres/auth/custom_access_token", + email: companyUser.Email.String(), + expectedError: "", + expectedStatus: http.StatusOK, + }, + } + + for _, c := range cases { + ts.T().Run(c.desc, func(t *testing.T) { + // Enable and set up the custom access token hook + ts.Config.Hook.CustomAccessToken.Enabled = true + ts.Config.Hook.CustomAccessToken.URI = c.uri + require.NoError(t, ts.Config.Hook.CustomAccessToken.PopulateExtensibilityPoint()) + + // Execute the common hook function SQL + err := ts.API.db.RawQuery(hookFunctionSQL).Exec() + require.NoError(t, err) + + var buffer bytes.Buffer + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": c.email, + "password": "password", + })) + + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(t, c.expectedStatus, w.Code, "Unexpected HTTP status code") + if c.expectedError != "" { + require.Contains(t, w.Body.String(), c.expectedError, "Expected error message not found") + } else { + require.NotContains(t, w.Body.String(), "error", "Unexpected error occurred") + } + + // Delete the function and cleanup + cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.CustomAccessToken.HookName) + require.NoError(t, ts.API.db.RawQuery(cleanupHookSQL).Exec()) + ts.Config.Hook.CustomAccessToken.Enabled = false + }) + } +} diff --git a/auth_v2.169.0/internal/api/user.go b/auth_v2.169.0/internal/api/user.go new file mode 100644 index 0000000..8588ce3 --- /dev/null +++ b/auth_v2.169.0/internal/api/user.go @@ -0,0 +1,266 @@ +package api + +import ( + "context" + "net/http" + "time" + + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +// UserUpdateParams parameters for updating a user +type UserUpdateParams struct { + Email string `json:"email"` + Password *string `json:"password"` + Nonce string `json:"nonce"` + Data map[string]interface{} `json:"data"` + AppData map[string]interface{} `json:"app_metadata,omitempty"` + Phone string 
`json:"phone"` + Channel string `json:"channel"` + CodeChallenge string `json:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method"` +} + +func (a *API) validateUserUpdateParams(ctx context.Context, p *UserUpdateParams) error { + config := a.config + + var err error + if p.Email != "" { + p.Email, err = a.validateEmail(p.Email) + if err != nil { + return err + } + } + + if p.Phone != "" { + if p.Phone, err = validatePhone(p.Phone); err != nil { + return err + } + if p.Channel == "" { + p.Channel = sms_provider.SMSProvider + } + if !sms_provider.IsValidMessageChannel(p.Channel, config) { + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) + } + } + + if p.Password != nil { + if err := a.checkPasswordStrength(ctx, *p.Password); err != nil { + return err + } + } + + return nil +} + +// UserGet returns a user +func (a *API) UserGet(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + claims := getClaims(ctx) + if claims == nil { + return internalServerError("Could not read claims") + } + + aud := a.requestAud(ctx, r) + audienceFromClaims, _ := claims.GetAudience() + if len(audienceFromClaims) == 0 || aud != audienceFromClaims[0] { + return badRequestError(ErrorCodeValidationFailed, "Token audience doesn't match request audience") + } + + user := getUser(ctx) + return sendJSON(w, http.StatusOK, user) +} + +// UserUpdate updates fields on a user +func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + config := a.config + aud := a.requestAud(ctx, r) + + params := &UserUpdateParams{} + if err := retrieveRequestParams(r, params); err != nil { + return err + } + + user := getUser(ctx) + session := getSession(ctx) + + if err := a.validateUserUpdateParams(ctx, params); err != nil { + return err + } + + if params.AppData != nil && !isAdmin(user, config) { + if !isAdmin(user, config) { + return forbiddenError(ErrorCodeNotAdmin, "Updating app_metadata requires admin privileges") + } + } + + if user.HasMFAEnabled() && !session.IsAAL2() { + if (params.Password != nil && *params.Password != "") || (params.Email != "" && user.GetEmail() != params.Email) || (params.Phone != "" && user.GetPhone() != params.Phone) { + return httpError(http.StatusUnauthorized, ErrorCodeInsufficientAAL, "AAL2 session is required to update email or password when MFA is enabled.") + } + } + + if user.IsAnonymous { + if params.Password != nil && *params.Password != "" { + if params.Email == "" && params.Phone == "" { + return unprocessableEntityError(ErrorCodeValidationFailed, "Updating password of an anonymous user without an email or phone is not allowed") + } + } + } + + if user.IsSSOUser { + updatingForbiddenFields := false + + updatingForbiddenFields = updatingForbiddenFields || (params.Password != nil && *params.Password != "") + updatingForbiddenFields = updatingForbiddenFields || (params.Email != "" && params.Email != user.GetEmail()) + updatingForbiddenFields = updatingForbiddenFields || (params.Phone != "" && params.Phone != user.GetPhone()) + updatingForbiddenFields = updatingForbiddenFields || (params.Nonce != "") + + if updatingForbiddenFields { + return unprocessableEntityError(ErrorCodeUserSSOManaged, "Updating email, phone, password of a SSO account only possible via SSO") + } + } + + if params.Email != "" && user.GetEmail() != params.Email { + if duplicateUser, err := models.IsDuplicatedEmail(db, params.Email, aud, user); err != nil { + return internalServerError("Database error 
checking email").WithInternalError(err) + } else if duplicateUser != nil { + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) + } + } + + if params.Phone != "" && user.GetPhone() != params.Phone { + if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil { + return internalServerError("Database error checking phone").WithInternalError(err) + } else if exists { + return unprocessableEntityError(ErrorCodePhoneExists, DuplicatePhoneMsg) + } + } + + if params.Password != nil { + if config.Security.UpdatePasswordRequireReauthentication { + now := time.Now() + // we require reauthentication if the user hasn't signed in recently in the current session + if session == nil || now.After(session.CreatedAt.Add(24*time.Hour)) { + if len(params.Nonce) == 0 { + return badRequestError(ErrorCodeReauthenticationNeeded, "Password update requires reauthentication") + } + if err := a.verifyReauthentication(params.Nonce, db, config, user); err != nil { + return err + } + } + } + + password := *params.Password + if password != "" { + isSamePassword := false + + if user.HasPassword() { + auth, _, err := user.Authenticate(ctx, db, password, config.Security.DBEncryption.DecryptionKeys, false, "") + if err != nil { + return err + } + + isSamePassword = auth + } + + if isSamePassword { + return unprocessableEntityError(ErrorCodeSamePassword, "New password should be different from the old password.") + } + } + + if err := user.SetPassword(ctx, password, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + return err + } + } + + err := db.Transaction(func(tx *storage.Connection) error { + var terr error + if params.Password != nil { + var sessionID *uuid.UUID + if session != nil { + sessionID = &session.ID + } + + if terr = user.UpdatePassword(tx, sessionID); terr != nil { + return internalServerError("Error during password storage").WithInternalError(terr) + } + + if terr := models.NewAuditLogEntry(r, tx, user, models.UserUpdatePasswordAction, "", nil); terr != nil { + return terr + } + } + + if params.Data != nil { + if terr = user.UpdateUserMetaData(tx, params.Data); terr != nil { + return internalServerError("Error updating user").WithInternalError(terr) + } + } + + if params.AppData != nil { + if terr = user.UpdateAppMetaData(tx, params.AppData); terr != nil { + return internalServerError("Error updating user").WithInternalError(terr) + } + } + + if params.Email != "" && params.Email != user.GetEmail() { + if user.IsAnonymous && config.Mailer.Autoconfirm { + // anonymous users can add an email with automatic confirmation, which is similar to signing up + // permanent users always need to verify their email address when changing it + user.EmailChange = params.Email + if _, terr := a.emailChangeVerify(r, tx, &VerifyParams{ + Type: mailer.EmailChangeVerification, + Email: params.Email, + }, user); terr != nil { + return terr + } + + } else { + flowType := getFlowFromChallenge(params.CodeChallenge) + if isPKCEFlow(flowType) { + _, terr := generateFlowState(tx, models.EmailChange.String(), models.EmailChange, params.CodeChallengeMethod, params.CodeChallenge, &user.ID) + if terr != nil { + return terr + } + + } + if terr = a.sendEmailChange(r, tx, user, params.Email, flowType); terr != nil { + return terr + } + } + } + + if params.Phone != "" && params.Phone != user.GetPhone() { + if config.Sms.Autoconfirm { + user.PhoneChange = params.Phone + if _, terr := a.smsVerify(r, tx, user, 
&VerifyParams{ + Type: phoneChangeVerification, + Phone: params.Phone, + }); terr != nil { + return terr + } + } else { + if _, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneChangeVerification, params.Channel); terr != nil { + return terr + } + } + } + + if terr = models.NewAuditLogEntry(r, tx, user, models.UserModifiedAction, "", nil); terr != nil { + return internalServerError("Error recording audit log entry").WithInternalError(terr) + } + + return nil + }) + if err != nil { + return err + } + + return sendJSON(w, http.StatusOK, user) +} diff --git a/auth_v2.169.0/internal/api/user_test.go b/auth_v2.169.0/internal/api/user_test.go new file mode 100644 index 0000000..ed6c585 --- /dev/null +++ b/auth_v2.169.0/internal/api/user_test.go @@ -0,0 +1,558 @@ +package api + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" +) + +type UserTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestUser(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &UserTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *UserTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create user + u, err := models.NewUser("123456789", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") +} + +func (ts *UserTestSuite) generateToken(user *models.User, sessionId *uuid.UUID) string { + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, user, sessionId, models.PasswordGrant) + require.NoError(ts.T(), err, "Error generating access token") + return token +} + +func (ts *UserTestSuite) generateAccessTokenAndSession(user *models.User) string { + session, err := models.NewSession(user.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(session)) + + req := httptest.NewRequest(http.MethodPost, "/token?grant_type=password", nil) + token, _, err := ts.API.generateAccessToken(req, ts.API.db, user, &session.ID, models.PasswordGrant) + require.NoError(ts.T(), err, "Error generating access token") + return token +} + +func (ts *UserTestSuite) TestUserGet() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err, "Error finding user") + token := ts.generateAccessTokenAndSession(u) + + require.NoError(ts.T(), err, "Error generating access token") + + req := httptest.NewRequest(http.MethodGet, "http://localhost/user", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) +} + +func (ts *UserTestSuite) TestUserUpdateEmail() { + cases := []struct { + desc string + userData map[string]interface{} + isSecureEmailChangeEnabled bool + isMailerAutoconfirmEnabled bool + expectedCode int + }{ + { + desc: "User doesn't have an existing email", + userData: map[string]interface{}{ + "email": "", + "phone": "", + }, + 
isSecureEmailChangeEnabled: false, + isMailerAutoconfirmEnabled: false, + expectedCode: http.StatusOK, + }, + { + desc: "User doesn't have an existing email and double email confirmation required", + userData: map[string]interface{}{ + "email": "", + "phone": "234567890", + }, + isSecureEmailChangeEnabled: true, + isMailerAutoconfirmEnabled: false, + expectedCode: http.StatusOK, + }, + { + desc: "User has an existing email", + userData: map[string]interface{}{ + "email": "foo@example.com", + "phone": "", + }, + isSecureEmailChangeEnabled: false, + isMailerAutoconfirmEnabled: false, + expectedCode: http.StatusOK, + }, + { + desc: "User has an existing email and double email confirmation required", + userData: map[string]interface{}{ + "email": "bar@example.com", + "phone": "", + }, + isSecureEmailChangeEnabled: true, + isMailerAutoconfirmEnabled: false, + expectedCode: http.StatusOK, + }, + { + desc: "Update email with mailer autoconfirm enabled", + userData: map[string]interface{}{ + "email": "bar@example.com", + "phone": "", + }, + isSecureEmailChangeEnabled: true, + isMailerAutoconfirmEnabled: true, + expectedCode: http.StatusOK, + }, + { + desc: "Update email with mailer autoconfirm enabled and anonymous user", + userData: map[string]interface{}{ + "email": "bar@example.com", + "phone": "", + "is_anonymous": true, + }, + isSecureEmailChangeEnabled: true, + isMailerAutoconfirmEnabled: true, + expectedCode: http.StatusOK, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + u, err := models.NewUser("", "", "", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), u.SetEmail(ts.API.db, c.userData["email"].(string)), "Error setting user email") + require.NoError(ts.T(), u.SetPhone(ts.API.db, c.userData["phone"].(string)), "Error setting user phone") + if isAnonymous, ok := c.userData["is_anonymous"]; ok { + u.IsAnonymous = isAnonymous.(bool) + } + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving test user") + + token := ts.generateAccessTokenAndSession(u) + + require.NoError(ts.T(), err, "Error generating access token") + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "new@example.com", + })) + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w := httptest.NewRecorder() + ts.Config.Mailer.SecureEmailChangeEnabled = c.isSecureEmailChangeEnabled + ts.Config.Mailer.Autoconfirm = c.isMailerAutoconfirmEnabled + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expectedCode, w.Code) + + var data models.User + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + if c.isMailerAutoconfirmEnabled && u.IsAnonymous { + require.Empty(ts.T(), data.EmailChange) + require.Equal(ts.T(), "new@example.com", data.GetEmail()) + require.Len(ts.T(), data.Identities, 1) + } else { + require.Equal(ts.T(), "new@example.com", data.EmailChange) + require.Len(ts.T(), data.Identities, 0) + } + + // remove user after each case + require.NoError(ts.T(), ts.API.db.Destroy(u)) + }) + } + +} +func (ts *UserTestSuite) TestUserUpdatePhoneAutoconfirmEnabled() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + existingUser, err := models.NewUser("22222222", "", "", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err) + 
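+	// existingUser owns phone "22222222" so the "New phone number exists already"
+	// case below can expect a 422 Unprocessable Entity response.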
require.NoError(ts.T(), ts.API.db.Create(existingUser)) + + cases := []struct { + desc string + userData map[string]string + expectedCode int + }{ + { + desc: "New phone number is the same as current phone number", + userData: map[string]string{ + "phone": "123456789", + }, + expectedCode: http.StatusOK, + }, + { + desc: "New phone number exists already", + userData: map[string]string{ + "phone": "22222222", + }, + expectedCode: http.StatusUnprocessableEntity, + }, + { + desc: "New phone number is different from current phone number", + userData: map[string]string{ + "phone": "234567890", + }, + expectedCode: http.StatusOK, + }, + } + + ts.Config.Sms.Autoconfirm = true + + for _, c := range cases { + ts.Run(c.desc, func() { + token := ts.generateAccessTokenAndSession(u) + require.NoError(ts.T(), err, "Error generating access token") + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "phone": c.userData["phone"], + })) + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expectedCode, w.Code) + + if c.expectedCode == http.StatusOK { + // check that the user response returned contains the updated phone field + data := &models.User{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), data.GetPhone(), c.userData["phone"]) + } + }) + } + +} + +func (ts *UserTestSuite) TestUserUpdatePassword() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + r, err := models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{}) + require.NoError(ts.T(), err) + + r2, err := models.GrantAuthenticatedUser(ts.API.db, u, models.GrantParams{}) + require.NoError(ts.T(), err) + + // create a session and modify it's created_at time to simulate a session that is not recently logged in + notRecentlyLoggedIn, err := models.FindSessionByID(ts.API.db, *r2.SessionId, true) + require.NoError(ts.T(), err) + + // cannot use Update here because Update doesn't removes the created_at field + require.NoError(ts.T(), ts.API.db.RawQuery( + "update "+notRecentlyLoggedIn.TableName()+" set created_at = ? 
where id = ?", + time.Now().Add(-24*time.Hour), + notRecentlyLoggedIn.ID).Exec(), + ) + + type expected struct { + code int + isAuthenticated bool + } + + var cases = []struct { + desc string + newPassword string + nonce string + requireReauthentication bool + sessionId *uuid.UUID + expected expected + }{ + { + desc: "Need reauthentication because outside of recently logged in window", + newPassword: "newpassword123", + nonce: "", + requireReauthentication: true, + sessionId: ¬RecentlyLoggedIn.ID, + expected: expected{code: http.StatusBadRequest, isAuthenticated: false}, + }, + { + desc: "No nonce provided", + newPassword: "newpassword123", + nonce: "", + sessionId: ¬RecentlyLoggedIn.ID, + requireReauthentication: true, + expected: expected{code: http.StatusBadRequest, isAuthenticated: false}, + }, + { + desc: "Invalid nonce", + newPassword: "newpassword1234", + nonce: "123456", + sessionId: ¬RecentlyLoggedIn.ID, + requireReauthentication: true, + expected: expected{code: http.StatusUnprocessableEntity, isAuthenticated: false}, + }, + { + desc: "No need reauthentication because recently logged in", + newPassword: "newpassword123", + nonce: "", + requireReauthentication: true, + sessionId: r.SessionId, + expected: expected{code: http.StatusOK, isAuthenticated: true}, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + ts.Config.Security.UpdatePasswordRequireReauthentication = c.requireReauthentication + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"password": c.newPassword, "nonce": c.nonce})) + + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + token := ts.generateToken(u, c.sessionId) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expected.code, w.Code) + + // Request body + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, c.newPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), c.expected.isAuthenticated, isAuthenticated) + }) + } +} + +func (ts *UserTestSuite) TestUserUpdatePasswordNoReauthenticationRequired() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + type expected struct { + code int + isAuthenticated bool + } + + var cases = []struct { + desc string + newPassword string + nonce string + requireReauthentication bool + expected expected + }{ + { + desc: "Invalid password length", + newPassword: "", + nonce: "", + requireReauthentication: false, + expected: expected{code: http.StatusUnprocessableEntity, isAuthenticated: false}, + }, + + { + desc: "Valid password length", + newPassword: "newpassword", + nonce: "", + requireReauthentication: false, + expected: expected{code: http.StatusOK, isAuthenticated: true}, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + ts.Config.Security.UpdatePasswordRequireReauthentication = c.requireReauthentication + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"password": c.newPassword, 
"nonce": c.nonce})) + + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + token := ts.generateAccessTokenAndSession(u) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expected.code, w.Code) + + // Request body + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, c.newPassword, ts.API.config.Security.DBEncryption.DecryptionKeys, ts.API.config.Security.DBEncryption.Encrypt, ts.API.config.Security.DBEncryption.EncryptionKeyID) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), c.expected.isAuthenticated, isAuthenticated) + }) + } +} + +func (ts *UserTestSuite) TestUserUpdatePasswordReauthentication() { + ts.Config.Security.UpdatePasswordRequireReauthentication = true + + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // Confirm the test user + now := time.Now() + u.EmailConfirmedAt = &now + require.NoError(ts.T(), ts.API.db.Update(u), "Error updating new test user") + + token := ts.generateAccessTokenAndSession(u) + + // request for reauthentication nonce + req := httptest.NewRequest(http.MethodGet, "http://localhost/reauthenticate", nil) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), w.Code, http.StatusOK) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.NotEmpty(ts.T(), u.ReauthenticationToken) + require.NotEmpty(ts.T(), u.ReauthenticationSentAt) + + // update reauthentication token to a known token + u.ReauthenticationToken = crypto.GenerateTokenHash(u.GetEmail(), "123456") + require.NoError(ts.T(), ts.API.db.Update(u)) + + // update password with reauthentication token + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "password": "newpass", + "nonce": "123456", + })) + + req = httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), w.Code, http.StatusOK) + + // Request body + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + isAuthenticated, _, err := u.Authenticate(context.Background(), ts.API.db, "newpass", ts.Config.Security.DBEncryption.DecryptionKeys, ts.Config.Security.DBEncryption.Encrypt, ts.Config.Security.DBEncryption.EncryptionKeyID) + require.NoError(ts.T(), err) + + require.True(ts.T(), isAuthenticated) + require.Empty(ts.T(), u.ReauthenticationToken) + require.Nil(ts.T(), u.ReauthenticationSentAt) +} + +func (ts *UserTestSuite) TestUserUpdatePasswordLogoutOtherSessions() { + ts.Config.Security.UpdatePasswordRequireReauthentication = false + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // Confirm the test user + now := time.Now() + u.EmailConfirmedAt = &now + require.NoError(ts.T(), 
ts.API.db.Update(u), "Error updating new test user") + + // Login the test user to get first session + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": u.GetEmail(), + "password": "password", + })) + req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + session1 := AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&session1)) + + // Login test user to get second session + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": u.GetEmail(), + "password": "password", + })) + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer) + req.Header.Set("Content-Type", "application/json") + + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + session2 := AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&session2)) + + // Update user's password using first session + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "password": "newpass", + })) + + req = httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", session1.Token)) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // Attempt to refresh session1 should pass + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": session1.RefreshToken, + })) + + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusOK, w.Code) + + // Attempt to refresh session2 should fail + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "refresh_token": session2.RefreshToken, + })) + + req = httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer) + req.Header.Set("Content-Type", "application/json") + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.NotEqual(ts.T(), http.StatusOK, w.Code) +} diff --git a/auth_v2.169.0/internal/api/verify.go b/auth_v2.169.0/internal/api/verify.go new file mode 100644 index 0000000..b42f5a5 --- /dev/null +++ b/auth_v2.169.0/internal/api/verify.go @@ -0,0 +1,749 @@ +package api + +import ( + "context" + "errors" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/fatih/structs" + "github.com/sethvargo/go-password/password" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/api/sms_provider" + "github.com/supabase/auth/internal/crypto" + mail "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +const ( + smsVerification = "sms" + phoneChangeVerification = "phone_change" + // includes signupVerification and magicLinkVerification +) + +const ( + zeroConfirmation int = iota + singleConfirmation +) + 
+// Only applicable when SECURE_EMAIL_CHANGE_ENABLED +const singleConfirmationAccepted = "Confirmation link accepted. Please proceed to confirm link sent to the other email" + +// VerifyParams are the parameters the Verify endpoint accepts +type VerifyParams struct { + Type string `json:"type"` + Token string `json:"token"` + TokenHash string `json:"token_hash"` + Email string `json:"email"` + Phone string `json:"phone"` + RedirectTo string `json:"redirect_to"` +} + +func (p *VerifyParams) Validate(r *http.Request, a *API) error { + var err error + if p.Type == "" { + return badRequestError(ErrorCodeValidationFailed, "Verify requires a verification type") + } + switch r.Method { + case http.MethodGet: + if p.Token == "" { + return badRequestError(ErrorCodeValidationFailed, "Verify requires a token or a token hash") + } + // TODO: deprecate the token query param from GET /verify and use token_hash instead (breaking change) + p.TokenHash = p.Token + case http.MethodPost: + if (p.Token == "" && p.TokenHash == "") || (p.Token != "" && p.TokenHash != "") { + return badRequestError(ErrorCodeValidationFailed, "Verify requires either a token or a token hash") + } + if p.Token != "" { + if isPhoneOtpVerification(p) { + p.Phone, err = validatePhone(p.Phone) + if err != nil { + return err + } + p.TokenHash = crypto.GenerateTokenHash(p.Phone, p.Token) + } else if isEmailOtpVerification(p) { + p.Email, err = a.validateEmail(p.Email) + if err != nil { + return unprocessableEntityError(ErrorCodeValidationFailed, "Invalid email format").WithInternalError(err) + } + p.TokenHash = crypto.GenerateTokenHash(p.Email, p.Token) + } else { + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify") + } + } else if p.TokenHash != "" { + if p.Email != "" || p.Phone != "" || p.RedirectTo != "" { + return badRequestError(ErrorCodeValidationFailed, "Only the token_hash and type should be provided") + } + } + default: + return nil + } + return nil +} + +// Verify exchanges a confirmation or recovery token to a refresh token +func (a *API) Verify(w http.ResponseWriter, r *http.Request) error { + params := &VerifyParams{} + switch r.Method { + case http.MethodGet: + params.Token = r.FormValue("token") + params.Type = r.FormValue("type") + params.RedirectTo = utilities.GetReferrer(r, a.config) + if err := params.Validate(r, a); err != nil { + return err + } + return a.verifyGet(w, r, params) + case http.MethodPost: + if err := retrieveRequestParams(r, params); err != nil { + return err + } + if err := params.Validate(r, a); err != nil { + return err + } + return a.verifyPost(w, r, params) + default: + // this should have been handled by Chi + panic("Only GET and POST methods allowed") + } +} + +func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyParams) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + var ( + user *models.User + grantParams models.GrantParams + err error + token *AccessTokenResponse + authCode string + rurl string + ) + + grantParams.FillGrantParams(r) + + flowType := models.ImplicitFlow + var authenticationMethod models.AuthenticationMethod + if strings.HasPrefix(params.Token, PKCEPrefix) { + flowType = models.PKCEFlow + authenticationMethod, err = models.ParseAuthenticationMethod(params.Type) + if err != nil { + return err + } + } + + err = db.Transaction(func(tx *storage.Connection) error { + var terr error + user, terr = a.verifyTokenHash(tx, params) + if terr != nil { + return terr + } + switch 
params.Type { + case mail.SignupVerification, mail.InviteVerification: + user, terr = a.signupVerify(r, ctx, tx, user) + case mail.RecoveryVerification, mail.MagicLinkVerification: + user, terr = a.recoverVerify(r, tx, user) + case mail.EmailChangeVerification: + user, terr = a.emailChangeVerify(r, tx, params, user) + if user == nil && terr == nil { + // only one OTP is confirmed at this point, so we return early and ask the user to confirm the second OTP + rurl, terr = a.prepRedirectURL(singleConfirmationAccepted, params.RedirectTo, flowType) + if terr != nil { + return terr + } + return nil + } + default: + return badRequestError(ErrorCodeValidationFailed, "Unsupported verification type") + } + + if terr != nil { + return terr + } + + if terr := user.UpdateAppMetaDataProviders(tx); terr != nil { + return terr + } + + // Reload user model from db. + // This is important for refreshing the data in any generated columns like IsAnonymous. + if terr := tx.Reload(user); err != nil { + return terr + } + + if isImplicitFlow(flowType) { + token, terr = a.issueRefreshToken(r, tx, user, models.OTP, grantParams) + if terr != nil { + return terr + } + + } else if isPKCEFlow(flowType) { + if authCode, terr = issueAuthCode(tx, user, authenticationMethod); terr != nil { + return badRequestError(ErrorCodeFlowStateNotFound, "No associated flow state found. %s", terr) + } + } + return nil + }) + + if err != nil { + var herr *HTTPError + if errors.As(err, &herr) { + rurl, err = a.prepErrorRedirectURL(herr, r, params.RedirectTo, flowType) + if err != nil { + return err + } + } + } + if rurl != "" { + http.Redirect(w, r, rurl, http.StatusSeeOther) + return nil + } + rurl = params.RedirectTo + if isImplicitFlow(flowType) && token != nil { + q := url.Values{} + q.Set("type", params.Type) + rurl = token.AsRedirectURL(rurl, q) + } else if isPKCEFlow(flowType) { + rurl, err = a.prepPKCERedirectURL(rurl, authCode) + if err != nil { + return err + } + } + http.Redirect(w, r, rurl, http.StatusSeeOther) + return nil +} + +func (a *API) verifyPost(w http.ResponseWriter, r *http.Request, params *VerifyParams) error { + ctx := r.Context() + db := a.db.WithContext(ctx) + + var ( + user *models.User + grantParams models.GrantParams + token *AccessTokenResponse + ) + var isSingleConfirmationResponse = false + + grantParams.FillGrantParams(r) + + err := db.Transaction(func(tx *storage.Connection) error { + var terr error + aud := a.requestAud(ctx, r) + + if isUsingTokenHash(params) { + user, terr = a.verifyTokenHash(tx, params) + } else { + user, terr = a.verifyUserAndToken(tx, params, aud) + } + if terr != nil { + return terr + } + + switch params.Type { + case mail.SignupVerification, mail.InviteVerification: + user, terr = a.signupVerify(r, ctx, tx, user) + case mail.RecoveryVerification, mail.MagicLinkVerification: + user, terr = a.recoverVerify(r, tx, user) + case mail.EmailChangeVerification: + user, terr = a.emailChangeVerify(r, tx, params, user) + if user == nil && terr == nil { + isSingleConfirmationResponse = true + return nil + } + case smsVerification, phoneChangeVerification: + user, terr = a.smsVerify(r, tx, user, params) + default: + return badRequestError(ErrorCodeValidationFailed, "Unsupported verification type") + } + + if terr != nil { + return terr + } + + if terr := user.UpdateAppMetaDataProviders(tx); terr != nil { + return terr + } + + // Reload user model from db. + // This is important for refreshing the data in any generated columns like IsAnonymous. 
+ if terr := tx.Reload(user); terr != nil { + return terr + } + token, terr = a.issueRefreshToken(r, tx, user, models.OTP, grantParams) + if terr != nil { + return terr + } + return nil + }) + if err != nil { + return err + } + if isSingleConfirmationResponse { + return sendJSON(w, http.StatusOK, map[string]string{ + "msg": singleConfirmationAccepted, + "code": strconv.Itoa(http.StatusOK), + }) + } + return sendJSON(w, http.StatusOK, token) +} + +func (a *API) signupVerify(r *http.Request, ctx context.Context, conn *storage.Connection, user *models.User) (*models.User, error) { + config := a.config + + shouldUpdatePassword := false + if !user.HasPassword() && user.InvitedAt != nil { + // sign them up with temporary password, and require application + // to present the user with a password set form + password, err := password.Generate(64, 10, 0, false, true) + if err != nil { + // password generation must succeed + panic(err) + } + + if err := user.SetPassword(ctx, password, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey); err != nil { + return nil, err + } + shouldUpdatePassword = true + } + + err := conn.Transaction(func(tx *storage.Connection) error { + var terr error + if shouldUpdatePassword { + if terr = user.UpdatePassword(tx, nil); terr != nil { + return internalServerError("Error storing password").WithInternalError(terr) + } + } + + if terr = models.NewAuditLogEntry(r, tx, user, models.UserSignedUpAction, "", nil); terr != nil { + return terr + } + + if terr = user.Confirm(tx); terr != nil { + return internalServerError("Error confirming user").WithInternalError(terr) + } + + for _, identity := range user.Identities { + if identity.Email == "" || user.Email == "" || identity.Email != user.Email { + continue + } + + if terr = identity.UpdateIdentityData(tx, map[string]interface{}{ + "email_verified": true, + }); terr != nil { + return internalServerError("Error setting email_verified to true on identity").WithInternalError(terr) + } + } + + return nil + }) + if err != nil { + return nil, err + } + return user, nil +} + +func (a *API) recoverVerify(r *http.Request, conn *storage.Connection, user *models.User) (*models.User, error) { + err := conn.Transaction(func(tx *storage.Connection) error { + var terr error + if terr = user.Recover(tx); terr != nil { + return terr + } + if !user.IsConfirmed() { + if terr = models.NewAuditLogEntry(r, tx, user, models.UserSignedUpAction, "", nil); terr != nil { + return terr + } + + if terr = user.Confirm(tx); terr != nil { + return terr + } + } else { + if terr = models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", nil); terr != nil { + return terr + } + } + return nil + }) + + if err != nil { + return nil, internalServerError("Database error updating user").WithInternalError(err) + } + return user, nil +} + +func (a *API) smsVerify(r *http.Request, conn *storage.Connection, user *models.User, params *VerifyParams) (*models.User, error) { + + err := conn.Transaction(func(tx *storage.Connection) error { + + if params.Type == smsVerification { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserSignedUpAction, "", nil); terr != nil { + return terr + } + if terr := user.ConfirmPhone(tx); terr != nil { + return internalServerError("Error confirming user").WithInternalError(terr) + } + } else if params.Type == phoneChangeVerification { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserModifiedAction, "", nil); terr != nil { + return terr + } + 
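+			// Look up the user's existing "phone" identity so its identity data can be
+			// updated with the new number; if none exists, one is created below.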
if identity, terr := models.FindIdentityByIdAndProvider(tx, user.ID.String(), "phone"); terr != nil { + if !models.IsNotFoundError(terr) { + return terr + } + // confirming the phone change should create a new phone identity if the user doesn't have one + if _, terr = a.createNewIdentity(tx, user, "phone", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Phone: params.Phone, + PhoneVerified: true, + })); terr != nil { + return terr + } + } else { + if terr := identity.UpdateIdentityData(tx, map[string]interface{}{ + "phone": params.Phone, + "phone_verified": true, + }); terr != nil { + return terr + } + } + if terr := user.ConfirmPhoneChange(tx); terr != nil { + return internalServerError("Error confirming user").WithInternalError(terr) + } + } + + if user.IsAnonymous { + user.IsAnonymous = false + if terr := tx.UpdateOnly(user, "is_anonymous"); terr != nil { + return terr + } + } + + if terr := tx.Load(user, "Identities"); terr != nil { + return internalServerError("Error refetching identities").WithInternalError(terr) + } + return nil + }) + if err != nil { + return nil, err + } + return user, nil +} + +func (a *API) prepErrorRedirectURL(err *HTTPError, r *http.Request, rurl string, flowType models.FlowType) (string, error) { + u, perr := url.Parse(rurl) + if perr != nil { + return "", err + } + q := u.Query() + + // Maintain separate query params for hash and query + hq := url.Values{} + log := observability.GetLogEntry(r).Entry + errorID := utilities.GetRequestID(r.Context()) + err.ErrorID = errorID + log.WithError(err.Cause()).Info(err.Error()) + if str, ok := oauthErrorMap[err.HTTPStatus]; ok { + hq.Set("error", str) + q.Set("error", str) + } + hq.Set("error_code", err.ErrorCode) + hq.Set("error_description", err.Message) + + q.Set("error_code", err.ErrorCode) + q.Set("error_description", err.Message) + if flowType == models.PKCEFlow { + // Additionally, may override existing error query param if set to PKCE. + u.RawQuery = q.Encode() + } + // Left as hash fragment to comply with spec. 
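+	// Error details always end up in the URL fragment; the query string is only
+	// rewritten for the PKCE flow above.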
+ u.Fragment = hq.Encode() + return u.String(), nil +} + +func (a *API) prepRedirectURL(message string, rurl string, flowType models.FlowType) (string, error) { + u, perr := url.Parse(rurl) + if perr != nil { + return "", perr + } + hq := url.Values{} + q := u.Query() + hq.Set("message", message) + if flowType == models.PKCEFlow { + q.Set("message", message) + } + u.RawQuery = q.Encode() + u.Fragment = hq.Encode() + return u.String(), nil +} + +func (a *API) prepPKCERedirectURL(rurl, code string) (string, error) { + u, err := url.Parse(rurl) + if err != nil { + return "", err + } + q := u.Query() + q.Set("code", code) + u.RawQuery = q.Encode() + return u.String(), nil +} + +func (a *API) emailChangeVerify(r *http.Request, conn *storage.Connection, params *VerifyParams, user *models.User) (*models.User, error) { + config := a.config + if !config.Mailer.Autoconfirm && + config.Mailer.SecureEmailChangeEnabled && + user.EmailChangeConfirmStatus == zeroConfirmation && + user.GetEmail() != "" { + err := conn.Transaction(func(tx *storage.Connection) error { + currentOTT, terr := models.FindOneTimeToken(tx, params.TokenHash, models.EmailChangeTokenCurrent) + if terr != nil && !models.IsNotFoundError(terr) { + return terr + } + + newOTT, terr := models.FindOneTimeToken(tx, params.TokenHash, models.EmailChangeTokenNew) + if terr != nil && !models.IsNotFoundError(terr) { + return terr + } + + user.EmailChangeConfirmStatus = singleConfirmation + + if params.Token == user.EmailChangeTokenCurrent || params.TokenHash == user.EmailChangeTokenCurrent || (currentOTT != nil && params.TokenHash == currentOTT.TokenHash) { + user.EmailChangeTokenCurrent = "" + if terr := models.ClearOneTimeTokenForUser(tx, user.ID, models.EmailChangeTokenCurrent); terr != nil { + return terr + } + } else if params.Token == user.EmailChangeTokenNew || params.TokenHash == user.EmailChangeTokenNew || (newOTT != nil && params.TokenHash == newOTT.TokenHash) { + user.EmailChangeTokenNew = "" + if terr := models.ClearOneTimeTokenForUser(tx, user.ID, models.EmailChangeTokenNew); terr != nil { + return terr + } + } + if terr := tx.UpdateOnly(user, "email_change_confirm_status", "email_change_token_current", "email_change_token_new"); terr != nil { + return terr + } + return nil + }) + if err != nil { + return nil, err + } + return nil, nil + } + + // one email is confirmed at this point if GOTRUE_MAILER_SECURE_EMAIL_CHANGE_ENABLED is enabled + err := conn.Transaction(func(tx *storage.Connection) error { + if terr := models.NewAuditLogEntry(r, tx, user, models.UserModifiedAction, "", nil); terr != nil { + return terr + } + + if identity, terr := models.FindIdentityByIdAndProvider(tx, user.ID.String(), "email"); terr != nil { + if !models.IsNotFoundError(terr) { + return terr + } + // confirming the email change should create a new email identity if the user doesn't have one + if _, terr = a.createNewIdentity(tx, user, "email", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Email: user.EmailChange, + EmailVerified: true, + })); terr != nil { + return terr + } + } else { + if terr := identity.UpdateIdentityData(tx, map[string]interface{}{ + "email": user.EmailChange, + "email_verified": true, + }); terr != nil { + return terr + } + } + if user.IsAnonymous { + user.IsAnonymous = false + if terr := tx.UpdateOnly(user, "is_anonymous"); terr != nil { + return terr + } + } + if terr := tx.Load(user, "Identities"); terr != nil { + return internalServerError("Error refetching identities").WithInternalError(terr) + } + if terr := 
user.ConfirmEmailChange(tx, zeroConfirmation); terr != nil { + return internalServerError("Error confirm email").WithInternalError(terr) + } + + return nil + }) + if err != nil { + return nil, err + } + + return user, nil +} + +func (a *API) verifyTokenHash(conn *storage.Connection, params *VerifyParams) (*models.User, error) { + config := a.config + + var user *models.User + var err error + switch params.Type { + case mail.EmailOTPVerification: + // need to find user by confirmation token or recovery token with the token hash + user, err = models.FindUserByConfirmationOrRecoveryToken(conn, params.TokenHash) + case mail.SignupVerification, mail.InviteVerification: + user, err = models.FindUserByConfirmationToken(conn, params.TokenHash) + case mail.RecoveryVerification, mail.MagicLinkVerification: + user, err = models.FindUserByRecoveryToken(conn, params.TokenHash) + case mail.EmailChangeVerification: + user, err = models.FindUserByEmailChangeToken(conn, params.TokenHash) + default: + return nil, badRequestError(ErrorCodeValidationFailed, "Invalid email verification type") + } + + if err != nil { + if models.IsNotFoundError(err) { + return nil, forbiddenError(ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalError(err) + } + return nil, internalServerError("Database error finding user from email link").WithInternalError(err) + } + + if user.IsBanned() { + return nil, forbiddenError(ErrorCodeUserBanned, "User is banned") + } + + var isExpired bool + switch params.Type { + case mail.EmailOTPVerification: + sentAt := user.ConfirmationSentAt + params.Type = "signup" + if user.RecoveryToken == params.TokenHash { + sentAt = user.RecoverySentAt + params.Type = "magiclink" + } + isExpired = isOtpExpired(sentAt, config.Mailer.OtpExp) + case mail.SignupVerification, mail.InviteVerification: + isExpired = isOtpExpired(user.ConfirmationSentAt, config.Mailer.OtpExp) + case mail.RecoveryVerification, mail.MagicLinkVerification: + isExpired = isOtpExpired(user.RecoverySentAt, config.Mailer.OtpExp) + case mail.EmailChangeVerification: + isExpired = isOtpExpired(user.EmailChangeSentAt, config.Mailer.OtpExp) + } + + if isExpired { + return nil, forbiddenError(ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalMessage("email link has expired") + } + + return user, nil +} + +// verifyUserAndToken verifies the token associated to the user based on the verify type +func (a *API) verifyUserAndToken(conn *storage.Connection, params *VerifyParams, aud string) (*models.User, error) { + config := a.config + + var user *models.User + var err error + tokenHash := params.TokenHash + + switch params.Type { + case phoneChangeVerification: + user, err = models.FindUserByPhoneChangeAndAudience(conn, params.Phone, aud) + case smsVerification: + user, err = models.FindUserByPhoneAndAudience(conn, params.Phone, aud) + case mail.EmailChangeVerification: + // Since the email change could be trigger via the implicit or PKCE flow, + // the query used has to also check if the token saved in the db contains the pkce_ prefix + user, err = models.FindUserForEmailChange(conn, params.Email, tokenHash, aud, config.Mailer.SecureEmailChangeEnabled) + default: + user, err = models.FindUserByEmailAndAudience(conn, params.Email, aud) + } + + if err != nil { + if models.IsNotFoundError(err) { + return nil, forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) + } + return nil, internalServerError("Database error finding user").WithInternalError(err) + } + + if 
user.IsBanned() { + return nil, forbiddenError(ErrorCodeUserBanned, "User is banned") + } + + var isValid bool + + smsProvider, _ := sms_provider.GetSmsProvider(*config) + switch params.Type { + case mail.EmailOTPVerification: + // if the type is emailOTPVerification, we'll check both the confirmation_token and recovery_token columns + if isOtpValid(tokenHash, user.ConfirmationToken, user.ConfirmationSentAt, config.Mailer.OtpExp) { + isValid = true + params.Type = mail.SignupVerification + } else if isOtpValid(tokenHash, user.RecoveryToken, user.RecoverySentAt, config.Mailer.OtpExp) { + isValid = true + params.Type = mail.MagicLinkVerification + } else { + isValid = false + } + case mail.SignupVerification, mail.InviteVerification: + isValid = isOtpValid(tokenHash, user.ConfirmationToken, user.ConfirmationSentAt, config.Mailer.OtpExp) + case mail.RecoveryVerification, mail.MagicLinkVerification: + isValid = isOtpValid(tokenHash, user.RecoveryToken, user.RecoverySentAt, config.Mailer.OtpExp) + case mail.EmailChangeVerification: + isValid = isOtpValid(tokenHash, user.EmailChangeTokenCurrent, user.EmailChangeSentAt, config.Mailer.OtpExp) || + isOtpValid(tokenHash, user.EmailChangeTokenNew, user.EmailChangeSentAt, config.Mailer.OtpExp) + case phoneChangeVerification, smsVerification: + if testOTP, ok := config.Sms.GetTestOTP(params.Phone, time.Now()); ok { + if params.Token == testOTP { + return user, nil + } + } + + phone := params.Phone + sentAt := user.ConfirmationSentAt + expectedToken := user.ConfirmationToken + if params.Type == phoneChangeVerification { + phone = user.PhoneChange + sentAt = user.PhoneChangeSentAt + expectedToken = user.PhoneChangeToken + } + + if !config.Hook.SendSMS.Enabled && config.Sms.IsTwilioVerifyProvider() { + if err := smsProvider.(*sms_provider.TwilioVerifyProvider).VerifyOTP(phone, params.Token); err != nil { + return nil, forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) + } + return user, nil + } + isValid = isOtpValid(tokenHash, expectedToken, sentAt, config.Sms.OtpExp) + } + + if !isValid { + return nil, forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalMessage("token has expired or is invalid") + } + return user, nil +} + +// isOtpValid checks the actual otp sent against the expected otp and ensures that it's within the valid window +func isOtpValid(actual, expected string, sentAt *time.Time, otpExp uint) bool { + if expected == "" || sentAt == nil { + return false + } + return !isOtpExpired(sentAt, otpExp) && ((actual == expected) || ("pkce_"+actual == expected)) +} + +func isOtpExpired(sentAt *time.Time, otpExp uint) bool { + return time.Now().After(sentAt.Add(time.Second * time.Duration(otpExp))) // #nosec G115 +} + +// isPhoneOtpVerification checks if the verification came from a phone otp +func isPhoneOtpVerification(params *VerifyParams) bool { + return params.Phone != "" && params.Email == "" +} + +// isEmailOtpVerification checks if the verification came from an email otp +func isEmailOtpVerification(params *VerifyParams) bool { + return params.Phone == "" && params.Email != "" +} + +func isUsingTokenHash(params *VerifyParams) bool { + return params.TokenHash != "" && params.Token == "" && params.Phone == "" && params.Email == "" +} diff --git a/auth_v2.169.0/internal/api/verify_test.go b/auth_v2.169.0/internal/api/verify_test.go new file mode 100644 index 0000000..7c97d69 --- /dev/null +++ b/auth_v2.169.0/internal/api/verify_test.go @@ -0,0 +1,1280 @@ +package api + 
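+// Tests for the /verify endpoint, including password recovery links and the
+// secure email change flow in both the implicit and PKCE variants.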
+import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + mail "github.com/supabase/auth/internal/mailer" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/models" +) + +type VerifyTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestVerify(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &VerifyTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *VerifyTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create user + u, err := models.NewUser("12345678", "test@example.com", "password", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + + // Create identity + i, err := models.NewIdentity(u, "email", map[string]interface{}{ + "sub": u.ID.String(), + "email": "test@example.com", + "email_verified": false, + }) + require.NoError(ts.T(), err, "Error creating test identity model") + require.NoError(ts.T(), ts.API.db.Create(i), "Error saving new test identity") +} + +func (ts *VerifyTestSuite) TestVerifyPasswordRecovery() { + // modify config so we don't hit rate limit from requesting recovery twice in 60s + ts.Config.SMTP.MaxFrequency = 60 + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.RecoverySentAt = &time.Time{} + require.NoError(ts.T(), ts.API.db.Update(u)) + testEmail := "test@example.com" + + cases := []struct { + desc string + body map[string]interface{} + isPKCE bool + }{ + { + desc: "Implict Flow Recovery", + body: map[string]interface{}{ + "email": testEmail, + }, + isPKCE: false, + }, + { + desc: "PKCE Flow", + body: map[string]interface{}{ + "email": testEmail, + // Code Challenge needs to be at least 43 characters long + "code_challenge": "6b151854-cc15-4e29-8db7-3d3a9f15b3066b151854-cc15-4e29-8db7-3d3a9f15b306", + "code_challenge_method": models.SHA256.String(), + }, + isPKCE: true, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // Reset user + u.EmailConfirmedAt = nil + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) + assert.False(ts.T(), u.IsConfirmed()) + + recoveryToken := u.RecoveryToken + + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.RecoveryVerification, recoveryToken) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, 
w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.True(ts.T(), u.IsConfirmed()) + + if c.isPKCE { + rURL, _ := w.Result().Location() + + f, err := url.ParseQuery(rURL.RawQuery) + require.NoError(ts.T(), err) + assert.NotEmpty(ts.T(), f.Get("code")) + } + }) + } +} + +func (ts *VerifyTestSuite) TestVerifySecureEmailChange() { + currentEmail := "test@example.com" + newEmail := "new@example.com" + + // Change from new email to current email and back to new email + cases := []struct { + desc string + body map[string]interface{} + isPKCE bool + currentEmail string + newEmail string + }{ + { + desc: "Implict Flow Email Change", + body: map[string]interface{}{ + "email": newEmail, + }, + isPKCE: false, + currentEmail: currentEmail, + newEmail: newEmail, + }, + { + desc: "PKCE Email Change", + body: map[string]interface{}{ + "email": currentEmail, + // Code Challenge needs to be at least 43 characters long + "code_challenge": "6b151854-cc15-4e29-8db7-3d3a9f15b3066b151854-cc15-4e29-8db7-3d3a9f15b306", + "code_challenge_method": models.SHA256.String(), + }, + isPKCE: true, + currentEmail: newEmail, + newEmail: currentEmail, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // reset user + u.EmailChangeSentAt = nil + u.EmailChangeTokenCurrent = "" + u.EmailChangeTokenNew = "" + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + // Setup request + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Generate access token for request and a mock session + var token string + session, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(session)) + + token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.MagicLink) + require.NoError(ts.T(), err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + currentTokenHash := u.EmailChangeTokenCurrent + newTokenHash := u.EmailChangeTokenNew + + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + assert.WithinDuration(ts.T(), time.Now(), *u.EmailChangeSentAt, 1*time.Second) + assert.False(ts.T(), u.IsConfirmed()) + + // Verify new email + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, newTokenHash) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusSeeOther, w.Code) + urlVal, err := url.Parse(w.Result().Header.Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + var v url.Values + if !c.isPKCE { + v, err = url.ParseQuery(urlVal.Fragment) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("message")) + } else if c.isPKCE { + v, err = 
url.ParseQuery(urlVal.RawQuery) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("message")) + + v, err = url.ParseQuery(urlVal.Fragment) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("message")) + } + + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), singleConfirmation, u.EmailChangeConfirmStatus) + + // Verify old email + reqURL = fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, currentTokenHash) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusSeeOther, w.Code) + + urlVal, err = url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + if !c.isPKCE { + v, err = url.ParseQuery(urlVal.Fragment) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("access_token")) + ts.Require().NotEmpty(v.Get("expires_in")) + ts.Require().NotEmpty(v.Get("refresh_token")) + } else if c.isPKCE { + v, err = url.ParseQuery(urlVal.RawQuery) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("code")) + } + + // user's email should've been updated to newEmail + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.newEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.Equal(ts.T(), zeroConfirmation, u.EmailChangeConfirmStatus) + + // Reset confirmation status after each test + u.EmailConfirmedAt = nil + require.NoError(ts.T(), ts.API.db.Update(u)) + }) + } +} + +func (ts *VerifyTestSuite) TestExpiredConfirmationToken() { + // verify variant testing not necessary in this test as it's testing + // the ConfirmationSentAt behavior, not the ConfirmationToken behavior + + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.ConfirmationToken = "asdf3" + sentTime := time.Now().Add(-48 * time.Hour) + u.ConfirmationSentAt = &sentTime + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + + // Setup request + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.SignupVerification, u.ConfirmationToken) + req := httptest.NewRequest(http.MethodGet, reqURL, nil) + + // Setup response recorder + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + + rurl, err := url.Parse(w.Header().Get("Location")) + require.NoError(ts.T(), err, "redirect url parse failed") + + f, err := url.ParseQuery(rurl.Fragment) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), ErrorCodeOTPExpired, f.Get("error_code")) + assert.Equal(ts.T(), "Email link is invalid or has expired", f.Get("error_description")) + assert.Equal(ts.T(), "access_denied", f.Get("error")) +} + +func (ts *VerifyTestSuite) TestInvalidOtp() { + u, err := models.FindUserByPhoneAndAudience(ts.API.db, "12345678", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + sentTime := time.Now().Add(-48 * time.Hour) + u.ConfirmationToken = "123456" + u.ConfirmationSentAt = &sentTime + u.PhoneChange = "22222222" + u.PhoneChangeToken = "123456" + u.PhoneChangeSentAt = &sentTime + u.EmailChange = "test@gmail.com" + u.EmailChangeTokenNew = "123456" + u.EmailChangeTokenCurrent = "123456" + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), 
models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.PhoneChange, u.PhoneChangeToken, models.PhoneChangeToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.EmailChange, u.EmailChangeTokenNew, models.EmailChangeTokenNew)) + + type ResponseBody struct { + Code int `json:"code"` + Msg string `json:"msg"` + } + + expectedResponse := ResponseBody{ + Code: http.StatusForbidden, + Msg: "Token has expired or is invalid", + } + + cases := []struct { + desc string + sentTime time.Time + body map[string]interface{} + expected ResponseBody + }{ + { + desc: "Expired SMS OTP", + sentTime: time.Now().Add(-48 * time.Hour), + body: map[string]interface{}{ + "type": smsVerification, + "token": u.ConfirmationToken, + "phone": u.GetPhone(), + }, + expected: expectedResponse, + }, + { + desc: "Invalid SMS OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": smsVerification, + "token": "invalid_otp", + "phone": u.GetPhone(), + }, + expected: expectedResponse, + }, + { + desc: "Invalid Phone Change OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": phoneChangeVerification, + "token": "invalid_otp", + "phone": u.PhoneChange, + }, + expected: expectedResponse, + }, + { + desc: "Invalid Email OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.SignupVerification, + "token": "invalid_otp", + "email": u.GetEmail(), + }, + expected: expectedResponse, + }, + { + desc: "Invalid Email Change", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.EmailChangeVerification, + "token": "invalid_otp", + "email": u.GetEmail(), + }, + expected: expectedResponse, + }, + } + + for _, caseItem := range cases { + c := caseItem + + ts.Run(c.desc, func() { + // update token sent time + sentTime = time.Now() + u.ConfirmationSentAt = &c.sentTime + require.NoError(ts.T(), ts.API.db.Update(u)) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + b, err := io.ReadAll(w.Body) + require.NoError(ts.T(), err) + var resp ResponseBody + err = json.Unmarshal(b, &resp) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), c.expected.Code, resp.Code) + assert.Equal(ts.T(), c.expected.Msg, resp.Msg) + + }) + } +} + +func (ts *VerifyTestSuite) TestExpiredRecoveryToken() { + // verify variant testing not necessary in this test as it's testing + // the RecoverySentAt behavior, not the RecoveryToken behavior + + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.RecoveryToken = "asdf3" + sentTime := time.Now().Add(-48 * time.Hour) + u.RecoverySentAt = &sentTime + require.NoError(ts.T(), ts.API.db.Update(u)) + + // Setup request + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", "signup", u.RecoveryToken) + req := httptest.NewRequest(http.MethodGet, reqURL, nil) + + // Setup response recorder + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + assert.Equal(ts.T(), 
http.StatusSeeOther, w.Code, w.Body.String()) +} + +func (ts *VerifyTestSuite) TestVerifyPermitedCustomUri() { + // verify variant testing not necessary in this test as it's testing + // the redirect URL behavior, not the RecoveryToken behavior + + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.RecoverySentAt = &time.Time{} + require.NoError(ts.T(), ts.API.db.Update(u)) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) + assert.False(ts.T(), u.IsConfirmed()) + + redirectURL, _ := url.Parse(ts.Config.URIAllowList[0]) + + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s&redirect_to=%s", "recovery", u.RecoveryToken, redirectURL.String()) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + rURL, _ := w.Result().Location() + assert.Equal(ts.T(), redirectURL.Hostname(), rURL.Hostname()) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.True(ts.T(), u.IsConfirmed()) +} + +func (ts *VerifyTestSuite) TestVerifyNotPermitedCustomUri() { + // verify variant testing not necessary in this test as it's testing + // the redirect URL behavior, not the RecoveryToken behavior + + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.RecoverySentAt = &time.Time{} + require.NoError(ts.T(), ts.API.db.Update(u)) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "test@example.com", + })) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) + assert.False(ts.T(), u.IsConfirmed()) + + fakeredirectURL, _ := url.Parse("http://custom-url.com") + siteURL, _ := url.Parse(ts.Config.SiteURL) + + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s&redirect_to=%s", "recovery", u.RecoveryToken, fakeredirectURL.String()) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + rURL, _ := w.Result().Location() + assert.Equal(ts.T(), siteURL.Hostname(), rURL.Hostname()) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + 
require.NoError(ts.T(), err) + assert.True(ts.T(), u.IsConfirmed()) +} + +func (ts *VerifyTestSuite) TestVerifySignupWithRedirectURLContainedPath() { + // verify variant testing not necessary in this test as it's testing + // the redirect URL behavior, not the RecoveryToken behavior + + testCases := []struct { + desc string + siteURL string + uriAllowList []string + requestredirectURL string + expectedredirectURL string + }{ + { + desc: "same site url and redirect url with path", + siteURL: "http://localhost:3000/#/", + uriAllowList: []string{"http://localhost:3000"}, + requestredirectURL: "http://localhost:3000/#/", + expectedredirectURL: "http://localhost:3000/#/", + }, + { + desc: "different site url and redirect url in allow list", + siteURL: "https://someapp-something.codemagic.app/#/", + uriAllowList: []string{"http://localhost:3000"}, + requestredirectURL: "http://localhost:3000", + expectedredirectURL: "http://localhost:3000", + }, + { + desc: "different site url and redirect url not in allow list", + siteURL: "https://someapp-something.codemagic.app/#/", + uriAllowList: []string{"http://localhost:3000"}, + requestredirectURL: "http://localhost:3000/docs", + expectedredirectURL: "https://someapp-something.codemagic.app/#/", + }, + { + desc: "same wildcard site url and redirect url in allow list", + siteURL: "http://sub.test.dev:3000/#/", + uriAllowList: []string{"http://*.test.dev:3000"}, + requestredirectURL: "http://sub.test.dev:3000/#/", + expectedredirectURL: "http://sub.test.dev:3000/#/", + }, + { + desc: "different wildcard site url and redirect url in allow list", + siteURL: "http://sub.test.dev/#/", + uriAllowList: []string{"http://*.other.dev:3000"}, + requestredirectURL: "http://sub.other.dev:3000", + expectedredirectURL: "http://sub.other.dev:3000", + }, + { + desc: "different wildcard site url and redirect url not in allow list", + siteURL: "http://test.dev:3000/#/", + uriAllowList: []string{"http://*.allowed.dev:3000"}, + requestredirectURL: "http://sub.test.dev:3000/#/", + expectedredirectURL: "http://test.dev:3000/#/", + }, + { + desc: "exact mobile deep link redirect url in allow list", + siteURL: "http://test.dev:3000/#/", + uriAllowList: []string{"twitter://timeline"}, + requestredirectURL: "twitter://timeline", + expectedredirectURL: "twitter://timeline", + }, + // previously the below example was not allowed and with good + // reason, however users do want flexibility in the redirect + // URL after the scheme, which is why the example is now corrected + { + desc: "wildcard mobile deep link redirect url in allow list", + siteURL: "http://test.dev:3000/#/", + uriAllowList: []string{"com.example.app://**"}, + requestredirectURL: "com.example.app://sign-in/v2", + expectedredirectURL: "com.example.app://sign-in/v2", + }, + { + desc: "redirect respects . separator", + siteURL: "http://localhost:3000", + uriAllowList: []string{"http://*.*.dev:3000"}, + requestredirectURL: "http://foo.bar.dev:3000", + expectedredirectURL: "http://foo.bar.dev:3000", + }, + { + desc: "redirect does not respect . 
separator", + siteURL: "http://localhost:3000", + uriAllowList: []string{"http://*.dev:3000"}, + requestredirectURL: "http://foo.bar.dev:3000", + expectedredirectURL: "http://localhost:3000", + }, + { + desc: "redirect respects / separator in url subdirectory", + siteURL: "http://localhost:3000", + uriAllowList: []string{"http://test.dev:3000/*/*"}, + requestredirectURL: "http://test.dev:3000/bar/foo", + expectedredirectURL: "http://test.dev:3000/bar/foo", + }, + { + desc: "redirect does not respect / separator in url subdirectory", + siteURL: "http://localhost:3000", + uriAllowList: []string{"http://test.dev:3000/*"}, + requestredirectURL: "http://test.dev:3000/bar/foo", + expectedredirectURL: "http://localhost:3000", + }, + } + + for _, tC := range testCases { + ts.Run(tC.desc, func() { + // prepare test data + ts.Config.SiteURL = tC.siteURL + redirectURL := tC.requestredirectURL + ts.Config.URIAllowList = tC.uriAllowList + ts.Config.ApplyDefaults() + + // set verify token to user as it actual do in magic link method + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.ConfirmationToken = "someToken" + sendTime := time.Now().Add(time.Hour) + u.ConfirmationSentAt = &sendTime + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s&redirect_to=%s", "signup", u.ConfirmationToken, redirectURL) + req := httptest.NewRequest(http.MethodGet, reqURL, nil) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + rURL, _ := w.Result().Location() + assert.Contains(ts.T(), rURL.String(), tC.expectedredirectURL) // redirected url starts with per test value + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.True(ts.T(), u.IsConfirmed()) + assert.True(ts.T(), u.UserMetaData["email_verified"].(bool)) + assert.True(ts.T(), u.Identities[0].IdentityData["email_verified"].(bool)) + }) + } +} + +func (ts *VerifyTestSuite) TestVerifyPKCEOTP() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + t := time.Now() + u.ConfirmationSentAt = &t + u.RecoverySentAt = &t + u.EmailChangeSentAt = &t + require.NoError(ts.T(), ts.API.db.Update(u)) + + cases := []struct { + desc string + payload *VerifyParams + authenticationMethod models.AuthenticationMethod + }{ + { + desc: "Verify user on signup", + payload: &VerifyParams{ + Type: "signup", + Token: "pkce_confirmation_token", + }, + authenticationMethod: models.EmailSignup, + }, + { + desc: "Verify magiclink", + payload: &VerifyParams{ + Type: "magiclink", + Token: "pkce_recovery_token", + }, + authenticationMethod: models.MagicLink, + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + // since the test user is the same, the tokens are being cleared after each successful verification attempt + // so we create them on each run + if c.payload.Type == "signup" { + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), c.payload.Token, models.ConfirmationToken)) + } else if c.payload.Type == "magiclink" { + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), c.payload.Token, models.RecoveryToken)) + 
} + + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.payload)) + codeChallenge := "codechallengecodechallengcodechallengcodechallengcodechallenge" + flowState := models.NewFlowState(c.authenticationMethod.String(), codeChallenge, models.SHA256, c.authenticationMethod, &u.ID) + require.NoError(ts.T(), ts.API.db.Create(flowState)) + + requestUrl := fmt.Sprintf("http://localhost/verify?type=%v&token=%v", c.payload.Type, c.payload.Token) + req := httptest.NewRequest(http.MethodGet, requestUrl, &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + rURL, _ := w.Result().Location() + + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.True(ts.T(), u.IsConfirmed()) + + f, err := url.ParseQuery(rURL.RawQuery) + require.NoError(ts.T(), err) + assert.NotEmpty(ts.T(), f.Get("code")) + }) + } + +} + +func (ts *VerifyTestSuite) TestVerifyBannedUser() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.ConfirmationToken = "confirmation_token" + u.RecoveryToken = "recovery_token" + u.EmailChangeTokenCurrent = "current_email_change_token" + u.EmailChangeTokenNew = "new_email_change_token" + t := time.Now() + u.ConfirmationSentAt = &t + u.RecoverySentAt = &t + u.EmailChangeSentAt = &t + + t = time.Now().Add(24 * time.Hour) + u.BannedUntil = &t + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.EmailChangeTokenNew, models.EmailChangeTokenNew)) + + cases := []struct { + desc string + payload *VerifyParams + }{ + { + desc: "Verify banned user on signup", + payload: &VerifyParams{ + Type: "signup", + Token: u.ConfirmationToken, + }, + }, + { + desc: "Verify banned user on invite", + payload: &VerifyParams{ + Type: "invite", + Token: u.ConfirmationToken, + }, + }, + { + desc: "Verify banned user on recover", + payload: &VerifyParams{ + Type: "recovery", + Token: u.RecoveryToken, + }, + }, + { + desc: "Verify banned user on magiclink", + payload: &VerifyParams{ + Type: "magiclink", + Token: u.RecoveryToken, + }, + }, + { + desc: "Verify banned user on email change", + payload: &VerifyParams{ + Type: "email_change", + Token: u.EmailChangeTokenCurrent, + }, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.payload)) + + requestUrl := fmt.Sprintf("http://localhost/verify?type=%v&token=%v", c.payload.Type, c.payload.Token) + req := httptest.NewRequest(http.MethodGet, requestUrl, &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + + rurl, err := url.Parse(w.Header().Get("Location")) + require.NoError(ts.T(), err, "redirect url parse failed") + + f, err := url.ParseQuery(rurl.Fragment) + require.NoError(ts.T(), err) + 
assert.Equal(ts.T(), ErrorCodeUserBanned, f.Get("error_code")) + }) + } +} + +func (ts *VerifyTestSuite) TestVerifyValidOtp() { + ts.Config.Mailer.SecureEmailChangeEnabled = true + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.EmailChange = "new@example.com" + u.Phone = "12345678" + u.PhoneChange = "1234567890" + require.NoError(ts.T(), ts.API.db.Update(u)) + + type expected struct { + code int + tokenHash string + } + + cases := []struct { + desc string + sentTime time.Time + body map[string]interface{} + expected + }{ + { + desc: "Valid SMS OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": smsVerification, + "token": "123456", + "phone": u.GetPhone(), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.GetPhone(), "123456"), + }, + }, + { + desc: "Valid Confirmation OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.SignupVerification, + "token": "123456", + "email": u.GetEmail(), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + }, + { + desc: "Valid Signup Token Hash", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.SignupVerification, + "token_hash": crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + }, + { + desc: "Valid Recovery OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.RecoveryVerification, + "token": "123456", + "email": u.GetEmail(), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + }, + { + desc: "Valid Email OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.EmailOTPVerification, + "token": "123456", + "email": u.GetEmail(), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + }, + { + desc: "Valid Email OTP (email casing shouldn't matter)", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.EmailOTPVerification, + "token": "123456", + "email": strings.ToUpper(u.GetEmail()), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + }, + { + desc: "Valid Email Change OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.EmailChangeVerification, + "token": "123456", + "email": u.EmailChange, + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.EmailChange, "123456"), + }, + }, + { + desc: "Valid Phone Change OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": phoneChangeVerification, + "token": "123456", + "phone": u.PhoneChange, + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.PhoneChange, "123456"), + }, + }, + { + desc: "Valid Email Change Token Hash", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.EmailChangeVerification, + "token_hash": crypto.GenerateTokenHash(u.EmailChange, "123456"), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.EmailChange, "123456"), + }, + }, + { + desc: "Valid Email Verification Type", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": mail.EmailOTPVerification, + "token_hash": 
crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + expected: expected{ + code: http.StatusOK, + tokenHash: crypto.GenerateTokenHash(u.GetEmail(), "123456"), + }, + }, + } + + for _, caseItem := range cases { + c := caseItem + ts.Run(c.desc, func() { + // create user + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + + u.ConfirmationSentAt = &c.sentTime + u.RecoverySentAt = &c.sentTime + u.EmailChangeSentAt = &c.sentTime + u.PhoneChangeSentAt = &c.sentTime + + u.ConfirmationToken = c.expected.tokenHash + u.RecoveryToken = c.expected.tokenHash + u.EmailChangeTokenNew = c.expected.tokenHash + u.PhoneChangeToken = c.expected.tokenHash + + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.RecoveryToken, models.RecoveryToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.EmailChangeTokenNew, models.EmailChangeTokenNew)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.PhoneChangeToken, models.PhoneChangeToken)) + + require.NoError(ts.T(), ts.API.db.Update(u)) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), c.expected.code, w.Code) + }) + } +} + +func (ts *VerifyTestSuite) TestSecureEmailChangeWithTokenHash() { + ts.Config.Mailer.SecureEmailChangeEnabled = true + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + u.EmailChange = "new@example.com" + require.NoError(ts.T(), ts.API.db.Update(u)) + + currentEmailChangeToken := crypto.GenerateTokenHash(string(u.Email), "123456") + newEmailChangeToken := crypto.GenerateTokenHash(u.EmailChange, "123456") + + cases := []struct { + desc string + firstVerificationBody map[string]interface{} + secondVerificationBody map[string]interface{} + expectedStatus int + }{ + { + desc: "Secure Email Change with Token Hash (Success)", + firstVerificationBody: map[string]interface{}{ + "type": mail.EmailChangeVerification, + "token_hash": currentEmailChangeToken, + }, + secondVerificationBody: map[string]interface{}{ + "type": mail.EmailChangeVerification, + "token_hash": newEmailChangeToken, + }, + expectedStatus: http.StatusOK, + }, + { + desc: "Secure Email Change with Token Hash. 
Reusing a token hash twice should fail", + firstVerificationBody: map[string]interface{}{ + "type": mail.EmailChangeVerification, + "token_hash": currentEmailChangeToken, + }, + secondVerificationBody: map[string]interface{}{ + "type": mail.EmailChangeVerification, + "token_hash": currentEmailChangeToken, + }, + expectedStatus: http.StatusForbidden, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // Set the corresponding email change tokens + u.EmailChangeTokenCurrent = currentEmailChangeToken + u.EmailChangeTokenNew = newEmailChangeToken + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", currentEmailChangeToken, models.EmailChangeTokenCurrent)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", newEmailChangeToken, models.EmailChangeTokenNew)) + + currentTime := time.Now() + u.EmailChangeSentAt = &currentTime + require.NoError(ts.T(), ts.API.db.Update(u)) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.firstVerificationBody)) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.secondVerificationBody)) + + // Setup second request + req = httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup second response recorder + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), c.expectedStatus, w.Code) + }) + } +} + +func (ts *VerifyTestSuite) TestPrepRedirectURL() { + escapedMessage := url.QueryEscape(singleConfirmationAccepted) + cases := []struct { + desc string + message string + rurl string + flowType models.FlowType + expected string + }{ + { + desc: "(PKCE): Redirect URL with additional query params", + message: singleConfirmationAccepted, + rurl: "https://example.com/?first=another&second=other", + flowType: models.PKCEFlow, + expected: fmt.Sprintf("https://example.com/?first=another&message=%s&second=other#message=%s", escapedMessage, escapedMessage), + }, + { + desc: "(PKCE): Query params in redirect url are overriden", + message: singleConfirmationAccepted, + rurl: "https://example.com/?message=Valid+redirect+URL", + flowType: models.PKCEFlow, + expected: fmt.Sprintf("https://example.com/?message=%s#message=%s", escapedMessage, escapedMessage), + }, + { + desc: "(Implicit): plain redirect url", + message: singleConfirmationAccepted, + rurl: "https://example.com/", + flowType: models.ImplicitFlow, + expected: fmt.Sprintf("https://example.com/#message=%s", escapedMessage), + }, + { + desc: "(Implicit): query params retained", + message: singleConfirmationAccepted, + rurl: "https://example.com/?first=another", + flowType: models.ImplicitFlow, + expected: fmt.Sprintf("https://example.com/?first=another#message=%s", escapedMessage), + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + rurl, err := ts.API.prepRedirectURL(c.message, c.rurl, c.flowType) + require.NoError(ts.T(), err) + require.Equal(ts.T(), c.expected, rurl) + }) + } +} + +func (ts *VerifyTestSuite) TestPrepErrorRedirectURL() { + const DefaultError = "Invalid redirect URL" + redirectError := 
fmt.Sprintf("error=invalid_request&error_code=validation_failed&error_description=%s", url.QueryEscape(DefaultError)) + + cases := []struct { + desc string + message string + rurl string + flowType models.FlowType + expected string + }{ + { + desc: "(PKCE): Error in both query params and hash fragment", + message: "Valid redirect URL", + rurl: "https://example.com/", + flowType: models.PKCEFlow, + expected: fmt.Sprintf("https://example.com/?%s#%s", redirectError, redirectError), + }, + { + desc: "(PKCE): Error with conflicting query params in redirect url", + message: DefaultError, + rurl: "https://example.com/?error=Error+to+be+overriden", + flowType: models.PKCEFlow, + expected: fmt.Sprintf("https://example.com/?%s#%s", redirectError, redirectError), + }, + { + desc: "(Implicit): plain redirect url", + message: DefaultError, + rurl: "https://example.com/", + flowType: models.ImplicitFlow, + expected: fmt.Sprintf("https://example.com/#%s", redirectError), + }, + { + desc: "(Implicit): query params preserved", + message: DefaultError, + rurl: "https://example.com/?test=param", + flowType: models.ImplicitFlow, + expected: fmt.Sprintf("https://example.com/?test=param#%s", redirectError), + }, + } + for _, c := range cases { + ts.Run(c.desc, func() { + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + rurl, err := ts.API.prepErrorRedirectURL(badRequestError(ErrorCodeValidationFailed, DefaultError), req, c.rurl, c.flowType) + require.NoError(ts.T(), err) + require.Equal(ts.T(), c.expected, rurl) + }) + } +} + +func (ts *VerifyTestSuite) TestVerifyValidateParams() { + cases := []struct { + desc string + params *VerifyParams + method string + expected error + }{ + { + desc: "Successful GET Verify", + params: &VerifyParams{ + Type: "signup", + Token: "some-token-hash", + }, + method: http.MethodGet, + expected: nil, + }, + { + desc: "Successful POST Verify (TokenHash)", + params: &VerifyParams{ + Type: "signup", + TokenHash: "some-token-hash", + }, + method: http.MethodPost, + expected: nil, + }, + { + desc: "Successful POST Verify (Token)", + params: &VerifyParams{ + Type: "signup", + Token: "some-token", + Email: "email@example.com", + }, + method: http.MethodPost, + expected: nil, + }, + // unsuccessful validations + { + desc: "Need to send email or phone number with token", + params: &VerifyParams{ + Type: "signup", + Token: "some-token", + }, + method: http.MethodPost, + expected: badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify"), + }, + { + desc: "Cannot send both TokenHash and Token", + params: &VerifyParams{ + Type: "signup", + Token: "some-token", + TokenHash: "some-token-hash", + }, + method: http.MethodPost, + expected: badRequestError(ErrorCodeValidationFailed, "Verify requires either a token or a token hash"), + }, + { + desc: "No verification type specified", + params: &VerifyParams{ + Token: "some-token", + Email: "email@example.com", + }, + method: http.MethodPost, + expected: badRequestError(ErrorCodeValidationFailed, "Verify requires a verification type"), + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + req := httptest.NewRequest(c.method, "http://localhost", nil) + err := c.params.Validate(req, ts.API) + require.Equal(ts.T(), c.expected, err) + }) + } +} diff --git a/auth_v2.169.0/internal/conf/configuration.go b/auth_v2.169.0/internal/conf/configuration.go new file mode 100644 index 0000000..c4d910d --- /dev/null +++ b/auth_v2.169.0/internal/conf/configuration.go @@ -0,0 
+1,1144 @@ +package conf + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "net/url" + "os" + "path/filepath" + "regexp" + "strings" + "text/template" + "time" + + "github.com/gobwas/glob" + "github.com/golang-jwt/jwt/v5" + "github.com/joho/godotenv" + "github.com/kelseyhightower/envconfig" + "github.com/lestrrat-go/jwx/v2/jwk" + "gopkg.in/gomail.v2" +) + +const defaultMinPasswordLength int = 6 +const defaultChallengeExpiryDuration float64 = 300 +const defaultFactorExpiryDuration time.Duration = 300 * time.Second +const defaultFlowStateExpiryDuration time.Duration = 300 * time.Second + +// See: https://www.postgresql.org/docs/7.0/syntax525.htm +var postgresNamesRegexp = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]{0,62}$`) + +// See: https://github.com/standard-webhooks/standard-webhooks/blob/main/spec/standard-webhooks.md +// We use 4 * Math.ceil(n/3) to obtain unpadded length in base 64 +// So this 4 * Math.ceil(24/3) = 32 and 4 * Math.ceil(64/3) = 88 for symmetric secrets +// Since Ed25519 key is 32 bytes so we have 4 * Math.ceil(32/3) = 44 +var symmetricSecretFormat = regexp.MustCompile(`^v1,whsec_[A-Za-z0-9+/=]{32,88}`) +var asymmetricSecretFormat = regexp.MustCompile(`^v1a,whpk_[A-Za-z0-9+/=]{44,}:whsk_[A-Za-z0-9+/=]{44,}$`) + +// Time is used to represent timestamps in the configuration, as envconfig has +// trouble parsing empty strings, due to time.Time.UnmarshalText(). +type Time struct { + time.Time +} + +func (t *Time) UnmarshalText(text []byte) error { + trimed := bytes.TrimSpace(text) + + if len(trimed) < 1 { + t.Time = time.Time{} + } else { + if err := t.Time.UnmarshalText(trimed); err != nil { + return err + } + } + + return nil +} + +// OAuthProviderConfiguration holds all config related to external account providers. +type OAuthProviderConfiguration struct { + ClientID []string `json:"client_id" split_words:"true"` + Secret string `json:"secret"` + RedirectURI string `json:"redirect_uri" split_words:"true"` + URL string `json:"url"` + ApiURL string `json:"api_url" split_words:"true"` + Enabled bool `json:"enabled"` + SkipNonceCheck bool `json:"skip_nonce_check" split_words:"true"` +} + +type AnonymousProviderConfiguration struct { + Enabled bool `json:"enabled" default:"false"` +} + +type EmailProviderConfiguration struct { + Enabled bool `json:"enabled" default:"true"` + + AuthorizedAddresses []string `json:"authorized_addresses" split_words:"true"` + + MagicLinkEnabled bool `json:"magic_link_enabled" default:"true" split_words:"true"` +} + +// DBConfiguration holds all the database related configuration. +type DBConfiguration struct { + Driver string `json:"driver" required:"true"` + URL string `json:"url" envconfig:"DATABASE_URL" required:"true"` + Namespace string `json:"namespace" envconfig:"DB_NAMESPACE" default:"auth"` + // MaxPoolSize defaults to 0 (unlimited). 
+ MaxPoolSize int `json:"max_pool_size" split_words:"true"` + MaxIdlePoolSize int `json:"max_idle_pool_size" split_words:"true"` + ConnMaxLifetime time.Duration `json:"conn_max_lifetime,omitempty" split_words:"true"` + ConnMaxIdleTime time.Duration `json:"conn_max_idle_time,omitempty" split_words:"true"` + HealthCheckPeriod time.Duration `json:"health_check_period" split_words:"true"` + MigrationsPath string `json:"migrations_path" split_words:"true" default:"./migrations"` + CleanupEnabled bool `json:"cleanup_enabled" split_words:"true" default:"false"` +} + +func (c *DBConfiguration) Validate() error { + return nil +} + +// JWTConfiguration holds all the JWT related configuration. +type JWTConfiguration struct { + Secret string `json:"secret" required:"true"` + Exp int `json:"exp"` + Aud string `json:"aud"` + AdminGroupName string `json:"admin_group_name" split_words:"true"` + AdminRoles []string `json:"admin_roles" split_words:"true"` + DefaultGroupName string `json:"default_group_name" split_words:"true"` + Issuer string `json:"issuer"` + KeyID string `json:"key_id" split_words:"true"` + Keys JwtKeysDecoder `json:"keys"` + ValidMethods []string `json:"-"` +} + +type MFAFactorTypeConfiguration struct { + EnrollEnabled bool `json:"enroll_enabled" split_words:"true" default:"false"` + VerifyEnabled bool `json:"verify_enabled" split_words:"true" default:"false"` +} + +type TOTPFactorTypeConfiguration struct { + EnrollEnabled bool `json:"enroll_enabled" split_words:"true" default:"true"` + VerifyEnabled bool `json:"verify_enabled" split_words:"true" default:"true"` +} + +type PhoneFactorTypeConfiguration struct { + // Default to false in order to ensure Phone MFA is opt-in + MFAFactorTypeConfiguration + OtpLength int `json:"otp_length" split_words:"true"` + SMSTemplate *template.Template `json:"-"` + MaxFrequency time.Duration `json:"max_frequency" split_words:"true"` + Template string `json:"template"` +} + +// MFAConfiguration holds all the MFA related Configuration +type MFAConfiguration struct { + ChallengeExpiryDuration float64 `json:"challenge_expiry_duration" default:"300" split_words:"true"` + FactorExpiryDuration time.Duration `json:"factor_expiry_duration" default:"300s" split_words:"true"` + RateLimitChallengeAndVerify float64 `split_words:"true" default:"15"` + MaxEnrolledFactors float64 `split_words:"true" default:"10"` + MaxVerifiedFactors int `split_words:"true" default:"10"` + Phone PhoneFactorTypeConfiguration `split_words:"true"` + TOTP TOTPFactorTypeConfiguration `split_words:"true"` + WebAuthn MFAFactorTypeConfiguration `split_words:"true"` +} + +type APIConfiguration struct { + Host string + Port string `envconfig:"PORT" default:"8081"` + Endpoint string + RequestIDHeader string `envconfig:"REQUEST_ID_HEADER"` + ExternalURL string `json:"external_url" envconfig:"API_EXTERNAL_URL" required:"true"` + MaxRequestDuration time.Duration `json:"max_request_duration" split_words:"true" default:"10s"` +} + +func (a *APIConfiguration) Validate() error { + _, err := url.ParseRequestURI(a.ExternalURL) + if err != nil { + return err + } + + return nil +} + +type SessionsConfiguration struct { + Timebox *time.Duration `json:"timebox"` + InactivityTimeout *time.Duration `json:"inactivity_timeout,omitempty" split_words:"true"` + + SinglePerUser bool `json:"single_per_user" split_words:"true"` + Tags []string `json:"tags,omitempty"` +} + +func (c *SessionsConfiguration) Validate() error { + if c.Timebox == nil { + return nil + } + + if *c.Timebox <= time.Duration(0) { + return 
fmt.Errorf("conf: session timebox duration must be positive when set, was %v", (*c.Timebox).String()) + } + + return nil +} + +type PasswordRequiredCharacters []string + +func (v *PasswordRequiredCharacters) Decode(value string) error { + parts := strings.Split(value, ":") + + for i := 0; i < len(parts)-1; i += 1 { + part := parts[i] + + if part == "" { + continue + } + + // part ended in escape character, so it should be joined with the next one + if part[len(part)-1] == '\\' { + parts[i] = part[0:len(part)-1] + ":" + parts[i+1] + parts[i+1] = "" + continue + } + } + + for _, part := range parts { + if part != "" { + *v = append(*v, part) + } + } + + return nil +} + +// HIBPBloomConfiguration configures a bloom cache for pwned passwords. Use +// this tool to gauge the Items and FalsePositives values: +// https://hur.st/bloomfilter +type HIBPBloomConfiguration struct { + Enabled bool `json:"enabled"` + Items uint `json:"items" default:"100000"` + FalsePositives float64 `json:"false_positives" split_words:"true" default:"0.0000099"` +} + +type HIBPConfiguration struct { + Enabled bool `json:"enabled"` + FailClosed bool `json:"fail_closed" split_words:"true"` + + UserAgent string `json:"user_agent" split_words:"true" default:"https://github.com/supabase/gotrue"` + + Bloom HIBPBloomConfiguration `json:"bloom"` +} + +type PasswordConfiguration struct { + MinLength int `json:"min_length" split_words:"true"` + + RequiredCharacters PasswordRequiredCharacters `json:"required_characters" split_words:"true"` + + HIBP HIBPConfiguration `json:"hibp"` +} + +// GlobalConfiguration holds all the configuration that applies to all instances. +type GlobalConfiguration struct { + API APIConfiguration + DB DBConfiguration + External ProviderConfiguration + Logging LoggingConfig `envconfig:"LOG"` + Profiler ProfilerConfig `envconfig:"PROFILER"` + OperatorToken string `split_words:"true" required:"false"` + Tracing TracingConfig + Metrics MetricsConfig + SMTP SMTPConfiguration + + RateLimitHeader string `split_words:"true"` + RateLimitEmailSent Rate `split_words:"true" default:"30"` + RateLimitSmsSent Rate `split_words:"true" default:"30"` + RateLimitVerify float64 `split_words:"true" default:"30"` + RateLimitTokenRefresh float64 `split_words:"true" default:"150"` + RateLimitSso float64 `split_words:"true" default:"30"` + RateLimitAnonymousUsers float64 `split_words:"true" default:"30"` + RateLimitOtp float64 `split_words:"true" default:"30"` + + SiteURL string `json:"site_url" split_words:"true" required:"true"` + URIAllowList []string `json:"uri_allow_list" split_words:"true"` + URIAllowListMap map[string]glob.Glob + Password PasswordConfiguration `json:"password"` + JWT JWTConfiguration `json:"jwt"` + Mailer MailerConfiguration `json:"mailer"` + Sms SmsProviderConfiguration `json:"sms"` + DisableSignup bool `json:"disable_signup" split_words:"true"` + Hook HookConfiguration `json:"hook" split_words:"true"` + Security SecurityConfiguration `json:"security"` + Sessions SessionsConfiguration `json:"sessions"` + MFA MFAConfiguration `json:"MFA"` + SAML SAMLConfiguration `json:"saml"` + CORS CORSConfiguration `json:"cors"` +} + +type CORSConfiguration struct { + AllowedHeaders []string `json:"allowed_headers" split_words:"true"` +} + +func (c *CORSConfiguration) AllAllowedHeaders(defaults []string) []string { + set := make(map[string]bool) + for _, header := range defaults { + set[header] = true + } + + var result []string + result = append(result, defaults...) 
+ + for _, header := range c.AllowedHeaders { + if !set[header] { + result = append(result, header) + } + + set[header] = true + } + + return result +} + +// EmailContentConfiguration holds the configuration for emails, both subjects and template URLs. +type EmailContentConfiguration struct { + Invite string `json:"invite"` + Confirmation string `json:"confirmation"` + Recovery string `json:"recovery"` + EmailChange string `json:"email_change" split_words:"true"` + MagicLink string `json:"magic_link" split_words:"true"` + Reauthentication string `json:"reauthentication"` +} + +type ProviderConfiguration struct { + AnonymousUsers AnonymousProviderConfiguration `json:"anonymous_users" split_words:"true"` + Apple OAuthProviderConfiguration `json:"apple"` + Azure OAuthProviderConfiguration `json:"azure"` + Bitbucket OAuthProviderConfiguration `json:"bitbucket"` + Discord OAuthProviderConfiguration `json:"discord"` + Facebook OAuthProviderConfiguration `json:"facebook"` + Figma OAuthProviderConfiguration `json:"figma"` + Fly OAuthProviderConfiguration `json:"fly"` + Github OAuthProviderConfiguration `json:"github"` + Gitlab OAuthProviderConfiguration `json:"gitlab"` + Google OAuthProviderConfiguration `json:"google"` + Kakao OAuthProviderConfiguration `json:"kakao"` + Notion OAuthProviderConfiguration `json:"notion"` + Keycloak OAuthProviderConfiguration `json:"keycloak"` + Linkedin OAuthProviderConfiguration `json:"linkedin"` + LinkedinOIDC OAuthProviderConfiguration `json:"linkedin_oidc" envconfig:"LINKEDIN_OIDC"` + Spotify OAuthProviderConfiguration `json:"spotify"` + Slack OAuthProviderConfiguration `json:"slack"` + SlackOIDC OAuthProviderConfiguration `json:"slack_oidc" envconfig:"SLACK_OIDC"` + Twitter OAuthProviderConfiguration `json:"twitter"` + Twitch OAuthProviderConfiguration `json:"twitch"` + VercelMarketplace OAuthProviderConfiguration `json:"vercel_marketplace" split_words:"true"` + WorkOS OAuthProviderConfiguration `json:"workos"` + Email EmailProviderConfiguration `json:"email"` + Phone PhoneProviderConfiguration `json:"phone"` + Zoom OAuthProviderConfiguration `json:"zoom"` + IosBundleId string `json:"ios_bundle_id" split_words:"true"` + RedirectURL string `json:"redirect_url"` + AllowedIdTokenIssuers []string `json:"allowed_id_token_issuers" split_words:"true"` + FlowStateExpiryDuration time.Duration `json:"flow_state_expiry_duration" split_words:"true"` +} + +type SMTPConfiguration struct { + MaxFrequency time.Duration `json:"max_frequency" split_words:"true"` + Host string `json:"host"` + Port int `json:"port,omitempty" default:"587"` + User string `json:"user"` + Pass string `json:"pass,omitempty"` + AdminEmail string `json:"admin_email" split_words:"true"` + SenderName string `json:"sender_name" split_words:"true"` + Headers string `json:"headers"` + LoggingEnabled bool `json:"logging_enabled" split_words:"true" default:"false"` + + fromAddress string `json:"-"` + normalizedHeaders map[string][]string `json:"-"` +} + +func (c *SMTPConfiguration) Validate() error { + headers := make(map[string][]string) + + if c.Headers != "" { + err := json.Unmarshal([]byte(c.Headers), &headers) + if err != nil { + return fmt.Errorf("conf: SMTP headers not a map[string][]string format: %w", err) + } + } + + if len(headers) > 0 { + c.normalizedHeaders = headers + } + + mail := gomail.NewMessage() + + c.fromAddress = mail.FormatAddress(c.AdminEmail, c.SenderName) + + return nil +} + +func (c *SMTPConfiguration) FromAddress() string { + return c.fromAddress +} + +func (c *SMTPConfiguration) 
NormalizedHeaders() map[string][]string { + return c.normalizedHeaders +} + +type MailerConfiguration struct { + Autoconfirm bool `json:"autoconfirm"` + AllowUnverifiedEmailSignIns bool `json:"allow_unverified_email_sign_ins" split_words:"true" default:"false"` + + Subjects EmailContentConfiguration `json:"subjects"` + Templates EmailContentConfiguration `json:"templates"` + URLPaths EmailContentConfiguration `json:"url_paths"` + + SecureEmailChangeEnabled bool `json:"secure_email_change_enabled" split_words:"true" default:"true"` + + OtpExp uint `json:"otp_exp" split_words:"true"` + OtpLength int `json:"otp_length" split_words:"true"` + + ExternalHosts []string `json:"external_hosts" split_words:"true"` + + // EXPERIMENTAL: May be removed in a future release. + EmailValidationExtended bool `json:"email_validation_extended" split_words:"true" default:"false"` + EmailValidationServiceURL string `json:"email_validation_service_url" split_words:"true"` + EmailValidationServiceHeaders string `json:"email_validation_service_headers" split_words:"true"` + + serviceHeaders map[string][]string `json:"-"` +} + +func (c *MailerConfiguration) Validate() error { + headers := make(map[string][]string) + + if c.EmailValidationServiceHeaders != "" { + err := json.Unmarshal([]byte(c.EmailValidationServiceHeaders), &headers) + if err != nil { + return fmt.Errorf("conf: mailer validation headers not a map[string][]string format: %w", err) + } + } + + if len(headers) > 0 { + c.serviceHeaders = headers + } + return nil +} + +func (c *MailerConfiguration) GetEmailValidationServiceHeaders() map[string][]string { + return c.serviceHeaders +} + +type PhoneProviderConfiguration struct { + Enabled bool `json:"enabled" default:"false"` +} + +type SmsProviderConfiguration struct { + Autoconfirm bool `json:"autoconfirm"` + MaxFrequency time.Duration `json:"max_frequency" split_words:"true"` + OtpExp uint `json:"otp_exp" split_words:"true"` + OtpLength int `json:"otp_length" split_words:"true"` + Provider string `json:"provider"` + Template string `json:"template"` + TestOTP map[string]string `json:"test_otp" split_words:"true"` + TestOTPValidUntil Time `json:"test_otp_valid_until" split_words:"true"` + SMSTemplate *template.Template `json:"-"` + + Twilio TwilioProviderConfiguration `json:"twilio"` + TwilioVerify TwilioVerifyProviderConfiguration `json:"twilio_verify" split_words:"true"` + Messagebird MessagebirdProviderConfiguration `json:"messagebird"` + Textlocal TextlocalProviderConfiguration `json:"textlocal"` + Vonage VonageProviderConfiguration `json:"vonage"` +} + +func (c *SmsProviderConfiguration) GetTestOTP(phone string, now time.Time) (string, bool) { + if c.TestOTP != nil && (c.TestOTPValidUntil.Time.IsZero() || now.Before(c.TestOTPValidUntil.Time)) { + testOTP, ok := c.TestOTP[phone] + return testOTP, ok + } + + return "", false +} + +type TwilioProviderConfiguration struct { + AccountSid string `json:"account_sid" split_words:"true"` + AuthToken string `json:"auth_token" split_words:"true"` + MessageServiceSid string `json:"message_service_sid" split_words:"true"` + ContentSid string `json:"content_sid" split_words:"true"` +} + +type TwilioVerifyProviderConfiguration struct { + AccountSid string `json:"account_sid" split_words:"true"` + AuthToken string `json:"auth_token" split_words:"true"` + MessageServiceSid string `json:"message_service_sid" split_words:"true"` +} + +type MessagebirdProviderConfiguration struct { + AccessKey string `json:"access_key" split_words:"true"` + Originator string 
`json:"originator" split_words:"true"` +} + +type TextlocalProviderConfiguration struct { + ApiKey string `json:"api_key" split_words:"true"` + Sender string `json:"sender" split_words:"true"` +} + +type VonageProviderConfiguration struct { + ApiKey string `json:"api_key" split_words:"true"` + ApiSecret string `json:"api_secret" split_words:"true"` + From string `json:"from" split_words:"true"` +} + +type CaptchaConfiguration struct { + Enabled bool `json:"enabled" default:"false"` + Provider string `json:"provider" default:"hcaptcha"` + Secret string `json:"provider_secret"` +} + +func (c *CaptchaConfiguration) Validate() error { + if !c.Enabled { + return nil + } + + if c.Provider != "hcaptcha" && c.Provider != "turnstile" { + return fmt.Errorf("unsupported captcha provider: %s", c.Provider) + } + + c.Secret = strings.TrimSpace(c.Secret) + + if c.Secret == "" { + return errors.New("captcha provider secret is empty") + } + + return nil +} + +// DatabaseEncryptionConfiguration configures Auth to encrypt certain columns. +// Once Encrypt is set to true, data will start getting encrypted with the +// provided encryption key. Setting it to false just stops encryption from +// going on further, but DecryptionKeys would have to contain the same key so +// the encrypted data remains accessible. +type DatabaseEncryptionConfiguration struct { + Encrypt bool `json:"encrypt"` + + EncryptionKeyID string `json:"encryption_key_id" split_words:"true"` + EncryptionKey string `json:"-" split_words:"true"` + + DecryptionKeys map[string]string `json:"-" split_words:"true"` +} + +func (c *DatabaseEncryptionConfiguration) Validate() error { + if c.Encrypt { + if c.EncryptionKeyID == "" { + return errors.New("conf: encryption key ID must be specified") + } + + decodedKey, err := base64.RawURLEncoding.DecodeString(c.EncryptionKey) + if err != nil { + return err + } + + if len(decodedKey) != 256/8 { + return errors.New("conf: encryption key is not 256 bits") + } + + if c.DecryptionKeys == nil || c.DecryptionKeys[c.EncryptionKeyID] == "" { + return errors.New("conf: encryption key must also be present in decryption keys") + } + } + + for id, key := range c.DecryptionKeys { + decodedKey, err := base64.RawURLEncoding.DecodeString(key) + if err != nil { + return err + } + + if len(decodedKey) != 256/8 { + return fmt.Errorf("conf: decryption key with ID %q must be 256 bits", id) + } + } + + return nil +} + +type SecurityConfiguration struct { + Captcha CaptchaConfiguration `json:"captcha"` + RefreshTokenRotationEnabled bool `json:"refresh_token_rotation_enabled" split_words:"true" default:"true"` + RefreshTokenReuseInterval int `json:"refresh_token_reuse_interval" split_words:"true"` + UpdatePasswordRequireReauthentication bool `json:"update_password_require_reauthentication" split_words:"true"` + ManualLinkingEnabled bool `json:"manual_linking_enabled" split_words:"true" default:"false"` + + DBEncryption DatabaseEncryptionConfiguration `json:"database_encryption" split_words:"true"` +} + +func (c *SecurityConfiguration) Validate() error { + if err := c.Captcha.Validate(); err != nil { + return err + } + + if err := c.DBEncryption.Validate(); err != nil { + return err + } + + return nil +} + +func loadEnvironment(filename string) error { + var err error + if filename != "" { + err = godotenv.Overload(filename) + } else { + err = godotenv.Load() + // handle if .env file does not exist, this is OK + if os.IsNotExist(err) { + return nil + } + } + return err +} + +// Moving away from the existing HookConfig so we can get 
a fresh start. +type HookConfiguration struct { + MFAVerificationAttempt ExtensibilityPointConfiguration `json:"mfa_verification_attempt" split_words:"true"` + PasswordVerificationAttempt ExtensibilityPointConfiguration `json:"password_verification_attempt" split_words:"true"` + CustomAccessToken ExtensibilityPointConfiguration `json:"custom_access_token" split_words:"true"` + SendEmail ExtensibilityPointConfiguration `json:"send_email" split_words:"true"` + SendSMS ExtensibilityPointConfiguration `json:"send_sms" split_words:"true"` +} + +type HTTPHookSecrets []string + +func (h *HTTPHookSecrets) Decode(value string) error { + parts := strings.Split(value, "|") + for _, part := range parts { + if part != "" { + *h = append(*h, part) + } + } + + return nil +} + +type ExtensibilityPointConfiguration struct { + URI string `json:"uri"` + Enabled bool `json:"enabled"` + // For internal use together with Postgres Hook. Not publicly exposed. + HookName string `json:"-"` + // We use | as a separator for keys and : as a separator for keys within a keypair. For instance: v1,whsec_test|v1a,whpk_myother:v1a,whsk_testkey|v1,whsec_secret3 + HTTPHookSecrets HTTPHookSecrets `json:"secrets" envconfig:"secrets"` +} + +func (h *HookConfiguration) Validate() error { + points := []ExtensibilityPointConfiguration{ + h.MFAVerificationAttempt, + h.PasswordVerificationAttempt, + h.CustomAccessToken, + h.SendSMS, + h.SendEmail, + } + for _, point := range points { + if err := point.ValidateExtensibilityPoint(); err != nil { + return err + } + } + return nil +} + +func (e *ExtensibilityPointConfiguration) ValidateExtensibilityPoint() error { + if e.URI == "" { + return nil + } + u, err := url.Parse(e.URI) + if err != nil { + return err + } + switch strings.ToLower(u.Scheme) { + case "pg-functions": + return validatePostgresPath(u) + case "http": + hostname := u.Hostname() + if hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" || hostname == "host.docker.internal" { + return validateHTTPHookSecrets(e.HTTPHookSecrets) + } + return fmt.Errorf("only localhost, 127.0.0.1, and ::1 are supported with http") + case "https": + return validateHTTPHookSecrets(e.HTTPHookSecrets) + default: + return fmt.Errorf("only postgres hooks and HTTPS functions are supported at the moment") + } +} + +func validatePostgresPath(u *url.URL) error { + pathParts := strings.Split(u.Path, "/") + if len(pathParts) < 3 { + return fmt.Errorf("URI path does not contain enough parts") + } + + schema := pathParts[1] + table := pathParts[2] + // Validate schema and table names + if !postgresNamesRegexp.MatchString(schema) { + return fmt.Errorf("invalid schema name: %s", schema) + } + if !postgresNamesRegexp.MatchString(table) { + return fmt.Errorf("invalid table name: %s", table) + } + return nil +} + +func isValidSecretFormat(secret string) bool { + return symmetricSecretFormat.MatchString(secret) || asymmetricSecretFormat.MatchString(secret) +} + +func validateHTTPHookSecrets(secrets []string) error { + for _, secret := range secrets { + if !isValidSecretFormat(secret) { + return fmt.Errorf("invalid secret format") + } + } + return nil +} + +func (e *ExtensibilityPointConfiguration) PopulateExtensibilityPoint() error { + u, err := url.Parse(e.URI) + if err != nil { + return err + } + if u.Scheme == "pg-functions" { + pathParts := strings.Split(u.Path, "/") + e.HookName = fmt.Sprintf("%q.%q", pathParts[1], pathParts[2]) + } + return nil +} + +// LoadFile calls godotenv.Load() when the given filename is empty ignoring any +// 
errors loading, otherwise it calls godotenv.Overload(filename). +// +// godotenv.Load: preserves env, ".env" path is optional +// godotenv.Overload: overrides env, "filename" path must exist +func LoadFile(filename string) error { + var err error + if filename != "" { + err = godotenv.Overload(filename) + } else { + err = godotenv.Load() + // handle if .env file does not exist, this is OK + if os.IsNotExist(err) { + return nil + } + } + return err +} + +// LoadDirectory does nothing when configDir is empty, otherwise it will attempt +// to load a list of configuration files located in configDir by using ReadDir +// to obtain a sorted list of files containing a .env suffix. +// +// When the list is empty it will do nothing, otherwise it passes the file list +// to godotenv.Overload to pull them into the current environment. +func LoadDirectory(configDir string) error { + if configDir == "" { + return nil + } + + // Returns entries sorted by filename + ents, err := os.ReadDir(configDir) + if err != nil { + // We mimic the behavior of LoadGlobal here, if an explicit path is + // provided we return an error. + return err + } + + var paths []string + for _, ent := range ents { + if ent.IsDir() { + continue // ignore directories + } + + // We only read files ending in .env + name := ent.Name() + if !strings.HasSuffix(name, ".env") { + continue + } + + // ent.Name() does not include the watch dir. + paths = append(paths, filepath.Join(configDir, name)) + } + + // If at least one path was found we load the configuration files in the + // directory. We don't call override without config files because it will + // override the env vars previously set with a ".env", if one exists. + if len(paths) > 0 { + if err := godotenv.Overload(paths...); err != nil { + return err + } + } + return nil +} + +// LoadGlobalFromEnv will return a new *GlobalConfiguration value from the +// currently configured environment. 
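+//
+// Minimal usage sketch (illustrative only; it assumes the relevant GOTRUE_*
+// variables have already been placed in the process environment, for example
+// via LoadFile or LoadDirectory):
+//
+//    cfg, err := conf.LoadGlobalFromEnv()
+//    if err != nil {
+//        log.Fatalf("configuration error: %v", err)
+//    }
+//    fmt.Println(cfg.API.ExternalURL)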
+func LoadGlobalFromEnv() (*GlobalConfiguration, error) { + config := new(GlobalConfiguration) + if err := loadGlobal(config); err != nil { + return nil, err + } + return config, nil +} + +func LoadGlobal(filename string) (*GlobalConfiguration, error) { + if err := loadEnvironment(filename); err != nil { + return nil, err + } + + config := new(GlobalConfiguration) + if err := loadGlobal(config); err != nil { + return nil, err + } + return config, nil +} + +func loadGlobal(config *GlobalConfiguration) error { + // although the package is called "auth" it used to be called "gotrue" + // so environment configs will remain to be called "GOTRUE" + if err := envconfig.Process("gotrue", config); err != nil { + return err + } + + if err := config.ApplyDefaults(); err != nil { + return err + } + + if err := config.Validate(); err != nil { + return err + } + + if config.Hook.PasswordVerificationAttempt.Enabled { + if err := config.Hook.PasswordVerificationAttempt.PopulateExtensibilityPoint(); err != nil { + return err + } + } + + if config.Hook.SendSMS.Enabled { + if err := config.Hook.SendSMS.PopulateExtensibilityPoint(); err != nil { + return err + } + } + if config.Hook.SendEmail.Enabled { + if err := config.Hook.SendEmail.PopulateExtensibilityPoint(); err != nil { + return err + } + } + + if config.Hook.MFAVerificationAttempt.Enabled { + if err := config.Hook.MFAVerificationAttempt.PopulateExtensibilityPoint(); err != nil { + return err + } + } + + if config.Hook.CustomAccessToken.Enabled { + if err := config.Hook.CustomAccessToken.PopulateExtensibilityPoint(); err != nil { + return err + } + } + + if config.SAML.Enabled { + if err := config.SAML.PopulateFields(config.API.ExternalURL); err != nil { + return err + } + } else { + config.SAML.PrivateKey = "" + } + + if config.Sms.Provider != "" { + SMSTemplate := config.Sms.Template + if SMSTemplate == "" { + SMSTemplate = "Your code is {{ .Code }}" + } + template, err := template.New("").Parse(SMSTemplate) + if err != nil { + return err + } + config.Sms.SMSTemplate = template + } + + if config.MFA.Phone.EnrollEnabled || config.MFA.Phone.VerifyEnabled { + smsTemplate := config.MFA.Phone.Template + if smsTemplate == "" { + smsTemplate = "Your code is {{ .Code }}" + } + template, err := template.New("").Parse(smsTemplate) + if err != nil { + return err + } + config.MFA.Phone.SMSTemplate = template + } + + return nil +} + +// ApplyDefaults sets defaults for a GlobalConfiguration +func (config *GlobalConfiguration) ApplyDefaults() error { + if config.JWT.AdminGroupName == "" { + config.JWT.AdminGroupName = "admin" + } + + if len(config.JWT.AdminRoles) == 0 { + config.JWT.AdminRoles = []string{"service_role", "supabase_admin"} + } + + if config.JWT.Exp == 0 { + config.JWT.Exp = 3600 + } + + if len(config.JWT.Keys) == 0 { + // transform the secret into a JWK for consistency + privKey, err := jwk.FromRaw([]byte(config.JWT.Secret)) + if err != nil { + return err + } + if config.JWT.KeyID != "" { + if err := privKey.Set(jwk.KeyIDKey, config.JWT.KeyID); err != nil { + return err + } + } + if privKey.Algorithm().String() == "" { + if err := privKey.Set(jwk.AlgorithmKey, jwt.SigningMethodHS256.Name); err != nil { + return err + } + } + if err := privKey.Set(jwk.KeyUsageKey, "sig"); err != nil { + return err + } + if len(privKey.KeyOps()) == 0 { + if err := privKey.Set(jwk.KeyOpsKey, jwk.KeyOperationList{jwk.KeyOpSign, jwk.KeyOpVerify}); err != nil { + return err + } + } + pubKey, err := privKey.PublicKey() + if err != nil { + return err + } + config.JWT.Keys = 
make(JwtKeysDecoder) + config.JWT.Keys[config.JWT.KeyID] = JwkInfo{ + PublicKey: pubKey, + PrivateKey: privKey, + } + } + + if config.JWT.ValidMethods == nil { + config.JWT.ValidMethods = []string{} + for _, key := range config.JWT.Keys { + alg := GetSigningAlg(key.PublicKey) + config.JWT.ValidMethods = append(config.JWT.ValidMethods, alg.Alg()) + } + + } + + if config.Mailer.Autoconfirm && config.Mailer.AllowUnverifiedEmailSignIns { + return errors.New("cannot enable both GOTRUE_MAILER_AUTOCONFIRM and GOTRUE_MAILER_ALLOW_UNVERIFIED_EMAIL_SIGN_INS") + } + + if config.Mailer.URLPaths.Invite == "" { + config.Mailer.URLPaths.Invite = "/verify" + } + + if config.Mailer.URLPaths.Confirmation == "" { + config.Mailer.URLPaths.Confirmation = "/verify" + } + + if config.Mailer.URLPaths.Recovery == "" { + config.Mailer.URLPaths.Recovery = "/verify" + } + + if config.Mailer.URLPaths.EmailChange == "" { + config.Mailer.URLPaths.EmailChange = "/verify" + } + + if config.Mailer.OtpExp == 0 { + config.Mailer.OtpExp = 86400 // 1 day + } + + if config.Mailer.OtpLength == 0 || config.Mailer.OtpLength < 6 || config.Mailer.OtpLength > 10 { + // 6-digit otp by default + config.Mailer.OtpLength = 6 + } + + if config.SMTP.MaxFrequency == 0 { + config.SMTP.MaxFrequency = 1 * time.Minute + } + + if config.Sms.MaxFrequency == 0 { + config.Sms.MaxFrequency = 1 * time.Minute + } + + if config.Sms.OtpExp == 0 { + config.Sms.OtpExp = 60 + } + + if config.Sms.OtpLength == 0 || config.Sms.OtpLength < 6 || config.Sms.OtpLength > 10 { + // 6-digit otp by default + config.Sms.OtpLength = 6 + } + + if config.Sms.TestOTP != nil { + formatTestOtps := make(map[string]string) + for phone, otp := range config.Sms.TestOTP { + phone = strings.ReplaceAll(strings.TrimPrefix(phone, "+"), " ", "") + formatTestOtps[phone] = otp + } + config.Sms.TestOTP = formatTestOtps + } + + if len(config.Sms.Template) == 0 { + config.Sms.Template = "" + } + + if config.URIAllowList == nil { + config.URIAllowList = []string{} + } + + if config.URIAllowList != nil { + config.URIAllowListMap = make(map[string]glob.Glob) + for _, uri := range config.URIAllowList { + g := glob.MustCompile(uri, '.', '/') + config.URIAllowListMap[uri] = g + } + } + + if config.Password.MinLength < defaultMinPasswordLength { + config.Password.MinLength = defaultMinPasswordLength + } + + if config.MFA.ChallengeExpiryDuration < defaultChallengeExpiryDuration { + config.MFA.ChallengeExpiryDuration = defaultChallengeExpiryDuration + } + + if config.MFA.FactorExpiryDuration < defaultFactorExpiryDuration { + config.MFA.FactorExpiryDuration = defaultFactorExpiryDuration + } + + if config.MFA.Phone.MaxFrequency == 0 { + config.MFA.Phone.MaxFrequency = 1 * time.Minute + } + + if config.MFA.Phone.OtpLength < 6 || config.MFA.Phone.OtpLength > 10 { + // 6-digit otp by default + config.MFA.Phone.OtpLength = 6 + } + + if config.External.FlowStateExpiryDuration < defaultFlowStateExpiryDuration { + config.External.FlowStateExpiryDuration = defaultFlowStateExpiryDuration + } + + if len(config.External.AllowedIdTokenIssuers) == 0 { + config.External.AllowedIdTokenIssuers = append(config.External.AllowedIdTokenIssuers, "https://appleid.apple.com", "https://accounts.google.com") + } + + return nil +} + +// Validate validates all of configuration. 
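+// It iterates over the sub-configurations that implement Validate() error
+// (API, DB, tracing, metrics, SMTP, mailer, SAML, security, sessions, hooks
+// and JWT keys) and returns the first error encountered. Rough sketch of
+// driving the same steps by hand (normally LoadGlobal/LoadGlobalFromEnv do
+// this for you):
+//
+//    cfg := new(conf.GlobalConfiguration)
+//    if err := envconfig.Process("gotrue", cfg); err != nil {
+//        log.Fatal(err)
+//    }
+//    if err := cfg.ApplyDefaults(); err != nil {
+//        log.Fatal(err)
+//    }
+//    if err := cfg.Validate(); err != nil {
+//        log.Fatal(err) // e.g. a malformed SAML private key surfaces here
+//    }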
+func (c *GlobalConfiguration) Validate() error { + validatables := []interface { + Validate() error + }{ + &c.API, + &c.DB, + &c.Tracing, + &c.Metrics, + &c.SMTP, + &c.Mailer, + &c.SAML, + &c.Security, + &c.Sessions, + &c.Hook, + &c.JWT.Keys, + } + + for _, validatable := range validatables { + if err := validatable.Validate(); err != nil { + return err + } + } + + return nil +} + +func (o *OAuthProviderConfiguration) ValidateOAuth() error { + if !o.Enabled { + return errors.New("provider is not enabled") + } + if len(o.ClientID) == 0 { + return errors.New("missing OAuth client ID") + } + if o.Secret == "" { + return errors.New("missing OAuth secret") + } + if o.RedirectURI == "" { + return errors.New("missing redirect URI") + } + return nil +} + +func (t *TwilioProviderConfiguration) Validate() error { + if t.AccountSid == "" { + return errors.New("missing Twilio account SID") + } + if t.AuthToken == "" { + return errors.New("missing Twilio auth token") + } + if t.MessageServiceSid == "" { + return errors.New("missing Twilio message service SID or Twilio phone number") + } + return nil +} + +func (t *TwilioVerifyProviderConfiguration) Validate() error { + if t.AccountSid == "" { + return errors.New("missing Twilio account SID") + } + if t.AuthToken == "" { + return errors.New("missing Twilio auth token") + } + if t.MessageServiceSid == "" { + return errors.New("missing Twilio message service SID or Twilio phone number") + } + return nil +} + +func (t *MessagebirdProviderConfiguration) Validate() error { + if t.AccessKey == "" { + return errors.New("missing Messagebird access key") + } + if t.Originator == "" { + return errors.New("missing Messagebird originator") + } + return nil +} + +func (t *TextlocalProviderConfiguration) Validate() error { + if t.ApiKey == "" { + return errors.New("missing Textlocal API key") + } + if t.Sender == "" { + return errors.New("missing Textlocal sender") + } + return nil +} + +func (t *VonageProviderConfiguration) Validate() error { + if t.ApiKey == "" { + return errors.New("missing Vonage API key") + } + if t.ApiSecret == "" { + return errors.New("missing Vonage API secret") + } + if t.From == "" { + return errors.New("missing Vonage 'from' parameter") + } + return nil +} + +func (t *SmsProviderConfiguration) IsTwilioVerifyProvider() bool { + return t.Provider == "twilio_verify" +} diff --git a/auth_v2.169.0/internal/conf/configuration_test.go b/auth_v2.169.0/internal/conf/configuration_test.go new file mode 100644 index 0000000..c03f954 --- /dev/null +++ b/auth_v2.169.0/internal/conf/configuration_test.go @@ -0,0 +1,246 @@ +package conf + +import ( + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMain(m *testing.M) { + defer os.Clearenv() + os.Exit(m.Run()) +} + +func TestGlobal(t *testing.T) { + os.Setenv("GOTRUE_SITE_URL", "http://localhost:8080") + os.Setenv("GOTRUE_DB_DRIVER", "postgres") + os.Setenv("GOTRUE_DB_DATABASE_URL", "fake") + os.Setenv("GOTRUE_OPERATOR_TOKEN", "token") + os.Setenv("GOTRUE_API_REQUEST_ID_HEADER", "X-Request-ID") + os.Setenv("GOTRUE_JWT_SECRET", "secret") + os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") + os.Setenv("GOTRUE_HOOK_MFA_VERIFICATION_ATTEMPT_URI", "pg-functions://postgres/auth/count_failed_attempts") + os.Setenv("GOTRUE_HOOK_SEND_SMS_SECRETS", "v1,whsec_aWxpa2VzdXBhYmFzZXZlcnltdWNoYW5kaWhvcGV5b3Vkb3Rvbw==") + os.Setenv("GOTRUE_SMTP_HEADERS", 
`{"X-PM-Metadata-project-ref":["project_ref"],"X-SES-Message-Tags":["ses:feedback-id-a=project_ref,ses:feedback-id-b=$messageType"]}`) + os.Setenv("GOTRUE_MAILER_EMAIL_VALIDATION_SERVICE_HEADERS", `{"apikey":["test"]}`) + os.Setenv("GOTRUE_SMTP_LOGGING_ENABLED", "true") + gc, err := LoadGlobal("") + require.NoError(t, err) + assert.Equal(t, true, gc.SMTP.LoggingEnabled) + assert.Equal(t, "project_ref", gc.SMTP.NormalizedHeaders()["X-PM-Metadata-project-ref"][0]) + require.NotNil(t, gc) + assert.Equal(t, "X-Request-ID", gc.API.RequestIDHeader) + assert.Equal(t, "pg-functions://postgres/auth/count_failed_attempts", gc.Hook.MFAVerificationAttempt.URI) + + { + hdrs := gc.Mailer.GetEmailValidationServiceHeaders() + assert.Equal(t, 1, len(hdrs["apikey"])) + assert.Equal(t, "test", hdrs["apikey"][0]) + } + +} + +func TestRateLimits(t *testing.T) { + { + os.Setenv("GOTRUE_RATE_LIMIT_EMAIL_SENT", "0/1h") + + gc, err := LoadGlobal("") + require.NoError(t, err) + assert.Equal(t, float64(0), gc.RateLimitEmailSent.Events) + assert.Equal(t, time.Hour, gc.RateLimitEmailSent.OverTime) + } + + { + os.Setenv("GOTRUE_RATE_LIMIT_EMAIL_SENT", "10/1h") + + gc, err := LoadGlobal("") + require.NoError(t, err) + assert.Equal(t, float64(10), gc.RateLimitEmailSent.Events) + assert.Equal(t, time.Hour, gc.RateLimitEmailSent.OverTime) + } +} + +func TestPasswordRequiredCharactersDecode(t *testing.T) { + examples := []struct { + Value string + Result []string + }{ + { + Value: "a:b:c", + Result: []string{ + "a", + "b", + "c", + }, + }, + { + Value: "a\\:b:c", + Result: []string{ + "a:b", + "c", + }, + }, + { + Value: "a:b\\:c", + Result: []string{ + "a", + "b:c", + }, + }, + { + Value: "\\:a:b:c", + Result: []string{ + ":a", + "b", + "c", + }, + }, + { + Value: "a:b:c\\:", + Result: []string{ + "a", + "b", + "c:", + }, + }, + { + Value: "::\\::", + Result: []string{ + ":", + }, + }, + { + Value: "", + Result: nil, + }, + { + Value: " ", + Result: []string{ + " ", + }, + }, + } + + for i, example := range examples { + var into PasswordRequiredCharacters + require.NoError(t, into.Decode(example.Value), "Example %d failed with error", i) + + require.Equal(t, []string(into), example.Result, "Example %d got unexpected result", i) + } +} + +func TestHTTPHookSecretsDecode(t *testing.T) { + examples := []struct { + Value string + Result []string + }{ + { + Value: "v1,whsec_secret1|v1a,whpk_secrets:whsk_secret2|v1,whsec_secret3", + Result: []string{"v1,whsec_secret1", "v1a,whpk_secrets:whsk_secret2", "v1,whsec_secret3"}, + }, + { + Value: "v1,whsec_singlesecret", + Result: []string{"v1,whsec_singlesecret"}, + }, + { + Value: " ", + Result: []string{" "}, + }, + { + Value: "", + Result: nil, + }, + { + Value: "|a|b|c", + Result: []string{ + "a", + "b", + "c", + }, + }, + { + Value: "||||", + Result: nil, + }, + { + Value: "::", + Result: []string{"::"}, + }, + { + Value: "secret1::secret3", + Result: []string{"secret1::secret3"}, + }, + } + + for i, example := range examples { + var into HTTPHookSecrets + + require.NoError(t, into.Decode(example.Value), "Example %d failed with error", i) + require.Equal(t, []string(into), example.Result, "Example %d got unexpected result", i) + } +} + +func TestValidateExtensibilityPointURI(t *testing.T) { + cases := []struct { + desc string + uri string + expectError bool + }{ + // Positive test cases + {desc: "Valid HTTPS URI", uri: "https://asdfgggqqwwerty.website.co/functions/v1/custom-sms-sender", expectError: false}, + {desc: "Valid HTTPS URI", uri: 
"HTTPS://www.asdfgggqqwwerty.website.co/functions/v1/custom-sms-sender", expectError: false}, + {desc: "Valid Postgres URI", uri: "pg-functions://postgres/auth/verification_hook_reject", expectError: false}, + {desc: "Another Valid URI", uri: "pg-functions://postgres/user_management/add_user", expectError: false}, + {desc: "Another Valid URI", uri: "pg-functions://postgres/MySpeCial/FUNCTION_THAT_YELLS_AT_YOU", expectError: false}, + {desc: "Valid HTTP URI", uri: "http://localhost/functions/v1/custom-sms-sender", expectError: false}, + + // Negative test cases + {desc: "Invalid HTTP URI", uri: "http://asdfgggg.website.co/functions/v1/custom-sms-sender", expectError: true}, + {desc: "Invalid HTTPS URI (HTTP)", uri: "http://asdfgggqqwwerty.supabase.co/functions/v1/custom-sms-sender", expectError: true}, + {desc: "Invalid Schema Name", uri: "pg-functions://postgres/123auth/verification_hook_reject", expectError: true}, + {desc: "Invalid Function Name", uri: "pg-functions://postgres/auth/123verification_hook_reject", expectError: true}, + {desc: "Insufficient Path Parts", uri: "pg-functions://postgres/auth", expectError: true}, + } + + for _, tc := range cases { + ep := ExtensibilityPointConfiguration{URI: tc.uri} + err := ep.ValidateExtensibilityPoint() + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + } +} + +func TestValidateExtensibilityPointSecrets(t *testing.T) { + validHTTPSURI := "https://asdfgggqqwwerty.website.co/functions/v1/custom-sms-sender" + cases := []struct { + desc string + secret []string + expectError bool + }{ + // Positive test cases + {desc: "Valid Symmetric Secret", secret: []string{"v1,whsec_NDYzODhlNTY0ZGI1OWZjYTU2NjMwN2FhYzM3YzBkMWQ0NzVjNWRkNTJmZDU0MGNhYTAzMjVjNjQzMzE3Mjk2Zg====="}, expectError: false}, + {desc: "Valid Asymmetric Secret", secret: []string{"v1a,whpk_NDYzODhlNTY0ZGI1OWZjYTU2NjMwN2FhYzM3YzBkMWQ0NzVjNWRkNTJmZDU0MGNhYTAzMjVjNjQzMzE3Mjk2Zg==:whsk_abc889a6b1160015025064f108a48d6aba1c7c95fa8e304b4d225e8ae0121511"}, expectError: false}, + {desc: "Valid Mix of Symmetric and asymmetric Secret", secret: []string{"v1,whsec_2b49264c90fd15db3bb0e05f4e1547b9c183eb06d585be8a", "v1a,whpk_46388e564db59fca566307aac37c0d1d475c5dd52fd540caa0325c643317296f:whsk_YWJjODg5YTZiMTE2MDAxNTAyNTA2NGYxMDhhNDhkNmFiYTFjN2M5NWZhOGUzMDRiNGQyMjVlOGFlMDEyMTUxMSI="}, expectError: false}, + + // Negative test cases + {desc: "Invalid Asymmetric Secret", secret: []string{"v1a,john:jill", "jill"}, expectError: true}, + {desc: "Invalid Symmetric Secret", secret: []string{"tommy"}, expectError: true}, + } + for _, tc := range cases { + ep := ExtensibilityPointConfiguration{URI: validHTTPSURI, HTTPHookSecrets: tc.secret} + err := ep.ValidateExtensibilityPoint() + if tc.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + + } + +} diff --git a/auth_v2.169.0/internal/conf/jwk.go b/auth_v2.169.0/internal/conf/jwk.go new file mode 100644 index 0000000..fffb0c2 --- /dev/null +++ b/auth_v2.169.0/internal/conf/jwk.go @@ -0,0 +1,150 @@ +package conf + +import ( + "encoding/json" + "fmt" + + "github.com/golang-jwt/jwt/v5" + "github.com/lestrrat-go/jwx/v2/jwk" +) + +type JwtKeysDecoder map[string]JwkInfo + +type JwkInfo struct { + PublicKey jwk.Key `json:"public_key"` + PrivateKey jwk.Key `json:"private_key"` +} + +// Decode implements the Decoder interface +func (j *JwtKeysDecoder) Decode(value string) error { + data := make([]json.RawMessage, 0) + if err := json.Unmarshal([]byte(value), &data); err != nil { + return err + } + + config := 
JwtKeysDecoder{} + for _, key := range data { + privJwk, err := jwk.ParseKey(key) + if err != nil { + return err + } + pubJwk, err := jwk.PublicKeyOf(privJwk) + if err != nil { + return err + } + + // all public keys should have the the use claim set to 'sig + if err := pubJwk.Set(jwk.KeyUsageKey, "sig"); err != nil { + return err + } + + // all public keys should only have 'verify' set as the key_ops + if err := pubJwk.Set(jwk.KeyOpsKey, jwk.KeyOperationList{jwk.KeyOpVerify}); err != nil { + return err + } + + config[pubJwk.KeyID()] = JwkInfo{ + PublicKey: pubJwk, + PrivateKey: privJwk, + } + } + *j = config + return nil +} + +func (j *JwtKeysDecoder) Validate() error { + // Validate performs _minimal_ checks if the data stored in the key are valid. + // By minimal, we mean that it does not check if the key is valid for use in + // cryptographic operations. For example, it does not check if an RSA key's + // `e` field is a valid exponent, or if the `n` field is a valid modulus. + // Instead, it checks for things such as the _presence_ of some required fields, + // or if certain keys' values are of particular length. + // + // Note that depending on the underlying key type, use of this method requires + // that multiple fields in the key are properly populated. For example, an EC + // key's "x", "y" fields cannot be validated unless the "crv" field is populated first. + signingKeys := []jwk.Key{} + for _, key := range *j { + if err := key.PrivateKey.Validate(); err != nil { + return err + } + // symmetric keys don't have public keys + if key.PublicKey != nil { + if err := key.PublicKey.Validate(); err != nil { + return err + } + } + + for _, op := range key.PrivateKey.KeyOps() { + if op == jwk.KeyOpSign { + signingKeys = append(signingKeys, key.PrivateKey) + break + } + } + } + + switch { + case len(signingKeys) == 0: + return fmt.Errorf("no signing key detected") + case len(signingKeys) > 1: + return fmt.Errorf("multiple signing keys detected, only 1 signing key is supported") + } + + return nil +} + +func GetSigningJwk(config *JWTConfiguration) (jwk.Key, error) { + for _, key := range config.Keys { + for _, op := range key.PrivateKey.KeyOps() { + // the private JWK with key_ops "sign" should be used as the signing key + if op == jwk.KeyOpSign { + return key.PrivateKey, nil + } + } + } + return nil, fmt.Errorf("no signing key found") +} + +func GetSigningKey(k jwk.Key) (any, error) { + var key any + if err := k.Raw(&key); err != nil { + return nil, err + } + return key, nil +} + +func GetSigningAlg(k jwk.Key) jwt.SigningMethod { + if k == nil { + return jwt.SigningMethodHS256 + } + + switch (k).Algorithm().String() { + case "RS256": + return jwt.SigningMethodRS256 + case "RS512": + return jwt.SigningMethodRS512 + case "ES256": + return jwt.SigningMethodES256 + case "ES512": + return jwt.SigningMethodES512 + case "EdDSA": + return jwt.SigningMethodEdDSA + } + + // return HS256 to preserve existing behaviour + return jwt.SigningMethodHS256 +} + +func FindPublicKeyByKid(kid string, config *JWTConfiguration) (any, error) { + if k, ok := config.Keys[kid]; ok { + key, err := GetSigningKey(k.PublicKey) + if err != nil { + return nil, err + } + return key, nil + } + if kid == config.KeyID { + return []byte(config.Secret), nil + } + return nil, fmt.Errorf("invalid kid: %s", kid) +} diff --git a/auth_v2.169.0/internal/conf/jwk_test.go b/auth_v2.169.0/internal/conf/jwk_test.go new file mode 100644 index 0000000..275c0eb --- /dev/null +++ b/auth_v2.169.0/internal/conf/jwk_test.go @@ -0,0 +1,81 @@ 
+package conf + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDecode(t *testing.T) { + // array of JWKs containing 4 keys + gotrueJwtKeys := `[{"kty":"oct","k":"9Sj51i2YvfY85NJZFD6rAl9fKDxSKjFgW6W6ZXOJLnU","kid":"f90202bc-413a-4db3-8e04-b70a02a65669","key_ops":["verify"],"alg":"HS256"},{"kty":"RSA","n":"4slQjr-XoU6I1KXFWOeeJi387RIUxjhyzXX3GUVNb75a0SPKoGShlJEbpvuXqkDLGDweLcIZy-01nqgjSzMY_tUO3L78MxVfIVn7MByJ4_zbrVf5rjKeAk9EEMl6pb8nKJGArph9sOwL68LLioNySt_WNo_hMfuxUuVkRagh5gLjYoQ4odkULQrgwlMcXxXNnvg0aYURUr2SDmncHNuZQ3adebRlI164mUZPPWui2fg72R7c9qhVaAEzbdG-JAuC3zn5iL4zZk-8pOwZkM7Qb_2lrcXwdTl_Qz6fMdAHz_3rggac5oeKkdvO2x7_XiUwGxIBYSghxg5BBxcyqd6WrQ","e":"AQAB","d":"FjJo7uH4aUoktO8kHhbHbY_KSdQpHDjKyc7yTS_0DWYgUfdozzubJfRDF42vI-KsXssF-NoB0wJf0uP0L8ip6G326XPuoMQRTMgcaF8j6swTwsapSOEagr7BzcECx1zpc2-ojhwbLHSvRutWDzPJkbrUccF8vRC6BsiAUG4Hapiumbot7JtJGwU8ZUhxico7_OEJ_MtkRrHByXgrOMnzNLrmViI9rzvtWOhVc8sNDzLogDDi01AP0j6WeBhbOpaZ_1BMLQ9IeeN5Iiy-7Qj-q4-8kBXIPXpYaKMFnDTmhB0GAVUFimF6ojhZNAJvV81VMHPjrEmmps0_qBfIlKAB","p":"9G7wBpiSJHAl-w47AWvW60v_hye50lte4Ep2P3KeRyinzgxtEMivzldoqirwdoyPCJWwU7nNsv7AjdXVoHFy3fJvJeV5mhArxb2zA36OS_Tr3CQXtB3OO-RFwVcG7AGO7XvA54PK28siXY2VvkG2Xn_ZrbVebJnHQprn7ddUIIE","q":"7YSaG2E_M9XpgUJ0izwKdfGew6Hz5utPUdwMWjqr81BjtLkUtQ3tGYWs2tdaRYUTK4mNFyR2MjLYnMK-F37rue4LSKitmEu2N6RD9TwzcqwiEL_vuQTC985iJ0hzUC58LcbhYtTLU3KqZXXUqaeBXEwQAWxK1NRf6rQRhOGk4C0","dp":"fOV-sfAdpI7FaW3RCp3euGYh0B6lXW4goXyKxUq8w2FrtOY2iH_zDP0u1tyP-BNENr-91Fo5V__BxfeAa7XsWqo4zuVdaDJhG24d3Wg6L2ebaOXsUrV0Hrg6SFs-hzMYpBI69FEsQ3idO65P2GJdXBX51T-6WsWMwmTCo44GR4E","dq":"O2DrJe0p38ualLYIbMaV1uaQyleyoggxzEU20VfZpPpz8rpScvEIVVkV3Z_48WhTYo8AtshmxCXyAT6uRzFzvQfFymRhAbHr2_01ABoMwp5F5eoWBCsskscFwsxaB7GXWdpefla0figscTED-WXm8SwS1Eg-bParBAIAXzgKAAE","qi":"Cezqw8ECfMmwnRXJuiG2A93lzhixHxXISvGC-qbWaRmCfetheSviZlM0_KxF6dsvrw_aNfIPa8rv1TbN-5F04v_RU1CD79QuluzXWLkZVhPXorkK5e8sUi_odzAJXOwHKQzal5ndInl4XYctDHQr8jXcFW5Un65FhPwdAC6-aek","kid":"74b1a36b-4b39-467f-976b-acc7ec600a6d","key_ops":["verify"],"alg":"RS256"},{"kty":"EC","x":"GwbnH57MUhgL14dJfayyzuI6o2_mB_Pm8xIuauHXtQs","y":"cYqN0VAcv0BC9wrg3vNgHlKhGP8ZEedUC2A8jXpaGwA","crv":"P-256","d":"4STEXq7W4UY0piCGPueMaQqAAZ5jVRjjA_b1Hq7YgmM","kid":"fa3ffc99-4635-4b19-b5c0-6d6a8d30c4eb","key_ops":["sign","verify"],"alg":"ES256"},{"crv":"Ed25519","d":"T179kXSOJHE8CNbqaI2HNdG8r3YbSoKYxNRSzTkpEcY","x":"iDYagELzmD4z6uaW7eAZLuQ9fiUlnLqtrh7AfNbiNiI","kty":"OKP","kid":"b1176272-46e4-4226-b0bd-12eef4fd7367","key_ops":["verify"],"alg":"EdDSA"}]` + var decoder JwtKeysDecoder + require.NoError(t, decoder.Decode(gotrueJwtKeys)) + require.Len(t, decoder, 4) + + for kid, key := range decoder { + require.NotEmpty(t, kid) + require.NotNil(t, key.PrivateKey) + require.NotNil(t, key.PublicKey) + require.NotEmpty(t, key.PublicKey.KeyOps(), "missing key_ops claim") + } +} + +func TestJWTConfiguration(t *testing.T) { + // array of JWKs containing 4 keys + gotrueJwtKeys := 
`[{"kty":"oct","k":"9Sj51i2YvfY85NJZFD6rAl9fKDxSKjFgW6W6ZXOJLnU","kid":"f90202bc-413a-4db3-8e04-b70a02a65669","key_ops":["verify"],"alg":"HS256"},{"kty":"RSA","n":"4slQjr-XoU6I1KXFWOeeJi387RIUxjhyzXX3GUVNb75a0SPKoGShlJEbpvuXqkDLGDweLcIZy-01nqgjSzMY_tUO3L78MxVfIVn7MByJ4_zbrVf5rjKeAk9EEMl6pb8nKJGArph9sOwL68LLioNySt_WNo_hMfuxUuVkRagh5gLjYoQ4odkULQrgwlMcXxXNnvg0aYURUr2SDmncHNuZQ3adebRlI164mUZPPWui2fg72R7c9qhVaAEzbdG-JAuC3zn5iL4zZk-8pOwZkM7Qb_2lrcXwdTl_Qz6fMdAHz_3rggac5oeKkdvO2x7_XiUwGxIBYSghxg5BBxcyqd6WrQ","e":"AQAB","d":"FjJo7uH4aUoktO8kHhbHbY_KSdQpHDjKyc7yTS_0DWYgUfdozzubJfRDF42vI-KsXssF-NoB0wJf0uP0L8ip6G326XPuoMQRTMgcaF8j6swTwsapSOEagr7BzcECx1zpc2-ojhwbLHSvRutWDzPJkbrUccF8vRC6BsiAUG4Hapiumbot7JtJGwU8ZUhxico7_OEJ_MtkRrHByXgrOMnzNLrmViI9rzvtWOhVc8sNDzLogDDi01AP0j6WeBhbOpaZ_1BMLQ9IeeN5Iiy-7Qj-q4-8kBXIPXpYaKMFnDTmhB0GAVUFimF6ojhZNAJvV81VMHPjrEmmps0_qBfIlKAB","p":"9G7wBpiSJHAl-w47AWvW60v_hye50lte4Ep2P3KeRyinzgxtEMivzldoqirwdoyPCJWwU7nNsv7AjdXVoHFy3fJvJeV5mhArxb2zA36OS_Tr3CQXtB3OO-RFwVcG7AGO7XvA54PK28siXY2VvkG2Xn_ZrbVebJnHQprn7ddUIIE","q":"7YSaG2E_M9XpgUJ0izwKdfGew6Hz5utPUdwMWjqr81BjtLkUtQ3tGYWs2tdaRYUTK4mNFyR2MjLYnMK-F37rue4LSKitmEu2N6RD9TwzcqwiEL_vuQTC985iJ0hzUC58LcbhYtTLU3KqZXXUqaeBXEwQAWxK1NRf6rQRhOGk4C0","dp":"fOV-sfAdpI7FaW3RCp3euGYh0B6lXW4goXyKxUq8w2FrtOY2iH_zDP0u1tyP-BNENr-91Fo5V__BxfeAa7XsWqo4zuVdaDJhG24d3Wg6L2ebaOXsUrV0Hrg6SFs-hzMYpBI69FEsQ3idO65P2GJdXBX51T-6WsWMwmTCo44GR4E","dq":"O2DrJe0p38ualLYIbMaV1uaQyleyoggxzEU20VfZpPpz8rpScvEIVVkV3Z_48WhTYo8AtshmxCXyAT6uRzFzvQfFymRhAbHr2_01ABoMwp5F5eoWBCsskscFwsxaB7GXWdpefla0figscTED-WXm8SwS1Eg-bParBAIAXzgKAAE","qi":"Cezqw8ECfMmwnRXJuiG2A93lzhixHxXISvGC-qbWaRmCfetheSviZlM0_KxF6dsvrw_aNfIPa8rv1TbN-5F04v_RU1CD79QuluzXWLkZVhPXorkK5e8sUi_odzAJXOwHKQzal5ndInl4XYctDHQr8jXcFW5Un65FhPwdAC6-aek","kid":"74b1a36b-4b39-467f-976b-acc7ec600a6d","key_ops":["verify"],"alg":"RS256"},{"kty":"EC","x":"GwbnH57MUhgL14dJfayyzuI6o2_mB_Pm8xIuauHXtQs","y":"cYqN0VAcv0BC9wrg3vNgHlKhGP8ZEedUC2A8jXpaGwA","crv":"P-256","d":"4STEXq7W4UY0piCGPueMaQqAAZ5jVRjjA_b1Hq7YgmM","kid":"fa3ffc99-4635-4b19-b5c0-6d6a8d30c4eb","key_ops":["sign","verify"],"alg":"ES256"},{"crv":"Ed25519","d":"T179kXSOJHE8CNbqaI2HNdG8r3YbSoKYxNRSzTkpEcY","x":"iDYagELzmD4z6uaW7eAZLuQ9fiUlnLqtrh7AfNbiNiI","kty":"OKP","kid":"b1176272-46e4-4226-b0bd-12eef4fd7367","key_ops":["verify"],"alg":"EdDSA"}]` + var decoder JwtKeysDecoder + require.NoError(t, decoder.Decode(gotrueJwtKeys)) + require.Len(t, decoder, 4) + + cases := []struct { + desc string + config JWTConfiguration + expectedLength int + }{ + { + desc: "GOTRUE_JWT_KEYS is nil", + config: JWTConfiguration{ + Secret: "testsecret", + KeyID: "testkeyid", + }, + expectedLength: 1, + }, + { + desc: "GOTRUE_JWT_KEYS is an empty map", + config: JWTConfiguration{ + Secret: "testsecret", + KeyID: "testkeyid", + Keys: JwtKeysDecoder{}, + }, + expectedLength: 1, + }, + { + desc: "Prefer GOTRUE_JWT_KEYS over GOTRUE_JWT_SECRET", + config: JWTConfiguration{ + Secret: "testsecret", + KeyID: "testkeyid", + Keys: decoder, + }, + expectedLength: 4, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + globalConfig := GlobalConfiguration{ + JWT: c.config, + } + require.NoError(t, globalConfig.ApplyDefaults()) + require.NotEmpty(t, globalConfig.JWT.Keys) + require.Len(t, globalConfig.JWT.Keys, c.expectedLength) + for _, key := range globalConfig.JWT.Keys { + // public keys should contain these require claims + require.NotNil(t, key.PublicKey.Algorithm()) + require.NotNil(t, key.PublicKey.KeyID()) + require.NotNil(t, key.PublicKey.KeyOps()) + 
require.Equal(t, "sig", key.PublicKey.KeyUsage()) + } + }) + } +} diff --git a/auth_v2.169.0/internal/conf/logging.go b/auth_v2.169.0/internal/conf/logging.go new file mode 100644 index 0000000..d079006 --- /dev/null +++ b/auth_v2.169.0/internal/conf/logging.go @@ -0,0 +1,11 @@ +package conf + +type LoggingConfig struct { + Level string `mapstructure:"log_level" json:"log_level"` + File string `mapstructure:"log_file" json:"log_file"` + DisableColors bool `mapstructure:"disable_colors" split_words:"true" json:"disable_colors"` + QuoteEmptyFields bool `mapstructure:"quote_empty_fields" split_words:"true" json:"quote_empty_fields"` + TSFormat string `mapstructure:"ts_format" json:"ts_format"` + Fields map[string]interface{} `mapstructure:"fields" json:"fields"` + SQL string `mapstructure:"sql" json:"sql"` +} diff --git a/auth_v2.169.0/internal/conf/metrics.go b/auth_v2.169.0/internal/conf/metrics.go new file mode 100644 index 0000000..ac6f7ec --- /dev/null +++ b/auth_v2.169.0/internal/conf/metrics.go @@ -0,0 +1,26 @@ +package conf + +type MetricsExporter = string + +const ( + Prometheus MetricsExporter = "prometheus" + OpenTelemetryMetrics MetricsExporter = "opentelemetry" +) + +type MetricsConfig struct { + Enabled bool + + Exporter MetricsExporter `default:"opentelemetry"` + + // ExporterProtocol is the OTEL_EXPORTER_OTLP_PROTOCOL env variable, + // only available when exporter is opentelemetry. See: + // https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md + ExporterProtocol string `default:"http/protobuf" envconfig:"OTEL_EXPORTER_OTLP_PROTOCOL"` + + PrometheusListenHost string `default:"0.0.0.0" envconfig:"OTEL_EXPORTER_PROMETHEUS_HOST"` + PrometheusListenPort string `default:"9100" envconfig:"OTEL_EXPORTER_PROMETHEUS_PORT"` +} + +func (mc MetricsConfig) Validate() error { + return nil +} diff --git a/auth_v2.169.0/internal/conf/profiler.go b/auth_v2.169.0/internal/conf/profiler.go new file mode 100644 index 0000000..41752bf --- /dev/null +++ b/auth_v2.169.0/internal/conf/profiler.go @@ -0,0 +1,7 @@ +package conf + +type ProfilerConfig struct { + Enabled bool `default:"false"` + Host string `default:"localhost"` + Port string `default:"9998"` +} diff --git a/auth_v2.169.0/internal/conf/rate.go b/auth_v2.169.0/internal/conf/rate.go new file mode 100644 index 0000000..059ed65 --- /dev/null +++ b/auth_v2.169.0/internal/conf/rate.go @@ -0,0 +1,65 @@ +package conf + +import ( + "fmt" + "strconv" + "strings" + "time" +) + +const defaultOverTime = time.Hour + +const ( + BurstRateType = "burst" + IntervalRateType = "interval" +) + +type Rate struct { + Events float64 `json:"events,omitempty"` + OverTime time.Duration `json:"over_time,omitempty"` + typ string +} + +func (r *Rate) GetRateType() string { + if r.typ == "" { + return IntervalRateType + } + return r.typ +} + +// Decode is used by envconfig to parse the env-config string to a Rate value. 
+func (r *Rate) Decode(value string) error { + if f, err := strconv.ParseFloat(value, 64); err == nil { + r.typ = IntervalRateType + r.Events = f + r.OverTime = defaultOverTime + return nil + } + parts := strings.Split(value, "/") + if len(parts) != 2 { + return fmt.Errorf("rate: value does not match rate syntax %q", value) + } + + // 52 because the uint needs to fit in a float64 + e, err := strconv.ParseUint(parts[0], 10, 52) + if err != nil { + return fmt.Errorf("rate: events part of rate value %q failed to parse as uint64: %w", value, err) + } + + d, err := time.ParseDuration(parts[1]) + if err != nil { + return fmt.Errorf("rate: over-time part of rate value %q failed to parse as duration: %w", value, err) + } + + r.typ = BurstRateType + r.Events = float64(e) + r.OverTime = d + return nil +} + +func (r *Rate) String() string { + if r.OverTime == 0 { + return fmt.Sprintf("%f", r.Events) + } + return fmt.Sprintf("%d/%s", uint64(r.Events), r.OverTime.String()) +} diff --git a/auth_v2.169.0/internal/conf/rate_test.go b/auth_v2.169.0/internal/conf/rate_test.go new file mode 100644 index 0000000..378deda --- /dev/null +++ b/auth_v2.169.0/internal/conf/rate_test.go @@ -0,0 +1,68 @@ +package conf + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRateDecode(t *testing.T) { + cases := []struct { + str string + exp Rate + err string + }{ + {str: "1800", + exp: Rate{Events: 1800, OverTime: time.Hour, typ: IntervalRateType}}, + {str: "1800.0", + exp: Rate{Events: 1800, OverTime: time.Hour, typ: IntervalRateType}}, + {str: "3600/1h", + exp: Rate{Events: 3600, OverTime: time.Hour, typ: BurstRateType}}, + {str: "3600/1h0m0s", + exp: Rate{Events: 3600, OverTime: time.Hour, typ: BurstRateType}}, + {str: "100/24h", + exp: Rate{Events: 100, OverTime: time.Hour * 24, typ: BurstRateType}}, + {str: "", exp: Rate{}, + err: `rate: value does not match`}, + {str: "1h", exp: Rate{}, + err: `rate: value does not match`}, + {str: "/", exp: Rate{}, + err: `rate: events part of rate value`}, + {str: "/1h", exp: Rate{}, + err: `rate: events part of rate value`}, + {str: "3600.0/1h", exp: Rate{}, + err: `rate: events part of rate value "3600.0/1h" failed to parse`}, + {str: "100/", exp: Rate{}, + err: `rate: over-time part of rate value`}, + {str: "100/1", exp: Rate{}, + err: `rate: over-time part of rate value`}, + + // zero events + {str: "0/1h", + exp: Rate{Events: 0, OverTime: time.Hour, typ: BurstRateType}}, + {str: "0/24h", + exp: Rate{Events: 0, OverTime: time.Hour * 24, typ: BurstRateType}}, + } + for idx, tc := range cases { + var r Rate + err := r.Decode(tc.str) + require.Equal(t, tc.exp, r) // verify don't mutate r on errr + t.Logf("tc #%v - duration str %v", idx, tc.str) + if tc.err != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.err) + continue + } + require.NoError(t, err) + require.Equal(t, tc.exp, r) + require.Equal(t, tc.exp.typ, r.GetRateType()) + } + + // GetRateType() zero value + require.Equal(t, IntervalRateType, (&Rate{}).GetRateType()) + + // String() + require.Equal(t, "0.000000", (&Rate{}).String()) + require.Equal(t, "100/1h0m0s", (&Rate{Events: 100, OverTime: time.Hour}).String()) +} diff --git a/auth_v2.169.0/internal/conf/saml.go b/auth_v2.169.0/internal/conf/saml.go new file mode 100644 index 0000000..66a820c --- /dev/null +++ b/auth_v2.169.0/internal/conf/saml.go @@ -0,0 +1,136 @@ +package conf + +import ( + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "errors" + "fmt" + "math/big" + "net" + 
"net/url" + "time" +) + +// SAMLConfiguration holds configuration for native SAML support. +type SAMLConfiguration struct { + Enabled bool `json:"enabled"` + PrivateKey string `json:"-" split_words:"true"` + AllowEncryptedAssertions bool `json:"allow_encrypted_assertions" split_words:"true"` + RelayStateValidityPeriod time.Duration `json:"relay_state_validity_period" split_words:"true"` + + RSAPrivateKey *rsa.PrivateKey `json:"-"` + RSAPublicKey *rsa.PublicKey `json:"-"` + Certificate *x509.Certificate `json:"-"` + + ExternalURL string `json:"external_url,omitempty" split_words:"true"` + + RateLimitAssertion float64 `default:"15" split_words:"true"` +} + +func (c *SAMLConfiguration) Validate() error { + if c.Enabled { + bytes, err := base64.StdEncoding.DecodeString(c.PrivateKey) + if err != nil { + return errors.New("SAML private key not in standard Base64 format") + } + + privateKey, err := x509.ParsePKCS1PrivateKey(bytes) + if err != nil { + return errors.New("SAML private key not in PKCS#1 format") + } + + err = privateKey.Validate() + if err != nil { + return errors.New("SAML private key is not valid") + } + + if privateKey.E != 0x10001 { + return errors.New("SAML private key should use the 65537 (0x10001) RSA public exponent") + } + + if privateKey.N.BitLen() < 2048 { + return errors.New("SAML private key must be at least RSA 2048") + } + + if c.RelayStateValidityPeriod < 0 { + return errors.New("SAML RelayState validity period should be a positive duration") + } + + if c.ExternalURL != "" { + _, err := url.ParseRequestURI(c.ExternalURL) + if err != nil { + return err + } + } + } + + return nil +} + +// PopulateFields fills the configuration details based off the provided +// parameters. +func (c *SAMLConfiguration) PopulateFields(externalURL string) error { + // errors are intentionally ignored since they should have been handled + // within #Validate() + bytes, _ := base64.StdEncoding.DecodeString(c.PrivateKey) + privateKey, _ := x509.ParsePKCS1PrivateKey(bytes) + + c.RSAPrivateKey = privateKey + c.RSAPublicKey = privateKey.Public().(*rsa.PublicKey) + + parsedURL, err := url.ParseRequestURI(externalURL) + if err != nil { + return fmt.Errorf("saml: unable to parse external URL for SAML, check API_EXTERNAL_URL: %w", err) + } + + host := "" + host, _, err = net.SplitHostPort(parsedURL.Host) + if err != nil { + host = parsedURL.Host + } + + // SAML does not care much about the contents of the certificate, it + // only uses it as a vessel for the public key; therefore we set these + // fixed values. + // Please avoid modifying or adding new values to this template as they + // will change the exposed SAML certificate, requiring users of + // GoTrue to re-establish a connection between their Identity Provider + // and their running GoTrue instances. + certTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(0), + IsCA: false, + DNSNames: []string{ + "_samlsp." 
+ host, + }, + KeyUsage: x509.KeyUsageDigitalSignature, + NotBefore: time.UnixMilli(0).UTC(), + NotAfter: time.UnixMilli(0).UTC().AddDate(200, 0, 0), + Subject: pkix.Name{ + CommonName: "SAML 2.0 Certificate for " + host, + }, + } + + if c.AllowEncryptedAssertions { + certTemplate.KeyUsage = certTemplate.KeyUsage | x509.KeyUsageDataEncipherment + } + + certDer, err := x509.CreateCertificate(nil, certTemplate, certTemplate, c.RSAPublicKey, c.RSAPrivateKey) + if err != nil { + return err + } + + cert, err := x509.ParseCertificate(certDer) + if err != nil { + return err + } + + c.Certificate = cert + + if c.RelayStateValidityPeriod == 0 { + c.RelayStateValidityPeriod = 2 * time.Minute + } + + return nil +} diff --git a/auth_v2.169.0/internal/conf/saml_test.go b/auth_v2.169.0/internal/conf/saml_test.go new file mode 100644 index 0000000..e8de37e --- /dev/null +++ b/auth_v2.169.0/internal/conf/saml_test.go @@ -0,0 +1,101 @@ +package conf + +import ( + tst "testing" + + "encoding/base64" + + "github.com/stretchr/testify/require" +) + +func TestSAMLConfigurationValidate(t *tst.T) { + invalidExamples := []*SAMLConfiguration{ + { + Enabled: true, + PrivateKey: "", + }, + { + Enabled: true, + PrivateKey: "InvalidBase64!", + }, + { + Enabled: true, + PrivateKey: base64.StdEncoding.EncodeToString([]byte("not PKCS#1")), + }, + { + // RSA 1024 key + Enabled: true, + PrivateKey: "MIICXQIBAAKBgQDFa3SgzWZpcoONv3Iq3FxNieks2u2TmykxxxeggI9aNpHpuCzwGQO8wqXGVvFNlkE3GSPcz7rklzfyj577Z47lfWdBP1OAefralA3tS2mafqpZ32JwDynX4as+xauLVdP4iOR96b3L2eOb6rDpr4wBJuNqO533xsjcbNPINEDkSwIDAQABAoGASggBtEtSHDjVHFKufWQlOO5+glOWw8Nrrz75nTaYizvre7mVIHRA8ogLolT4KCAwVHkY+bTsYMxULqGs/JnY+40suHECYQ2u76PTQlvJnhJANGtCxuV4lSK6B8QBJhjGExsnAOwMMKz0p5kVftx2GA+/Rz2De7DR9keNECjcAAECQQDtr5cdkEdnIffvi782843EvX/g8615AKZeUYVUl0gVXujjpIVZXDtytPHINvIW1Z2mOm2rlJukwiKYYJ8IjsxlAkEA1KGbJ9EI6AOUcnpy7FYdGkbINTDngCqVOoaddlHS+1SaofpYXZPueXXIqIG3viksxmq/Q0IY6+JRkGo/RpGq7wJARD+BAqok9oYYbR4RX7P7ZxyKlYsiqnX3T2nVAP8XYZuI/6SD7a7AGyW9ryGnzcq0o8BvMS9QqbRcvqgvwgNOyQJBAL2ZVMaOSIjKGGZz9WHz74Nstj1n3CWW0vYa7vGASMc/S5s/pefbbvvzIPfQo0z3XiuXJ/ELUTmU1vIVK1L7tRUCQQCsuE7xckZ8H/523jdWWy9zZVP1z4c5dVLDR5RY+YQNForgb6kkSv4Gzn/FRUOxqn2MEWJLla31D4EuS+XKuwZR", + }, + { + // RSA 2048 with 0x11 as public exponent + Enabled: true, + PrivateKey: 
"MIIEowIBAAKCAQEAyMvTanPoiorCpIQCl70qXF34FIPOkKaInr1vw+3/0nik5CDUo761E02uTrK4/8JXr5NLGmy/fQmagNsBOdKewciRB3xxs+sPNncptG4rpCBjxSJdVl+mYZaw2kdvFY7TvNTlr7qG1Q0kV/3lBgpMlyM9OqBrjuG0UUzB5hlg08KLNflkQAkoJGWNVWULi2VceP3I3QsH9uNUQkgaM9Z6rl0BaRAkobHTTvquAqqj1AlNmSh24rrIbV4hYcNnesIpG4+LDd8XfpOwTp+jUl8akF6xcRBJjiPDJGN9ety29DcCxjo2i0b+TWYU+Pex08uOeOdulsgecbIVxLUEgRHcFQIBEQKCAQBefgkjCV5fUFuYtpfO75t2ws8Ytn9TITE7pHDUrDwmz1ynlvqnaM2uuyTZvYQ8HzhSn6rfQjv+mxuH7pcqRP9qQEQ/whdjuekKkm36jjKnlsWJ8g3OSyEe3YBmuDRGYVSVGOSO7l2Rb5ih4OQ/E+fOpyvfWoz38b5EYFs/GwBjpgJG+9cdCLYKOax8WDifWkjHdrogAlE8do/QF6RZoSvhAbRkpuxYActmKU8rIORrq8dLidSjBG2aoRH+RCN4ONZ3R4iHbYF2zWfqDFdSIX64kChaOZVhtTyTnF7/1v4VF3UwByEs8hTSckFH2jW6T7RZoatpgsv5zx/roRPDBWNRAoGBAPGphQwX9GF56XVmqD9rQMD9a0FuGCNGgiFkt2OUTFKmr4rTNVE8uJqEXjQwbybTnF2iJ1ApL1zNHg0cpnpMt7ZpcWG4Bu2UsXlwBL/ZwY4Spk6tHNdTsg/wuoWRSIGNanNS6CI5EUA4cxGNUt0G+dF4LaMHZuIAU7avs+kwDMzHAoGBANS1nS8KYkPUwYlmgVPNhMDTtjvq7fgP5UFDXnlhE6rJidc/+B0p9WiRhLGWlZebn+h2fELfIgK3qc4IzCHOkar0pic2D3bNbboNQKnqFl81hg0EORTK0JJ5/K4J61l5+rZtQu3Ss1HVwDiy9SKg6F3CQj9PK0r+hjtAStFSmZxDAoGBAMcEEzciyUE3OLsJP0NJRGKylJA8jFlJH99D4lIBqEQQzMyt76xQH45O5Cr6teO9U5hna6ttNhAwcxnbW+w/LeGEAwUuI9K2sEXjx60NrnUATLlDRO2QOElc1ddolhBWV6pERrLFlbxquR2DcWq6c2E1yzr3CW7TF8OfwVagCoqFAoGBAK8sJxeuMs5y+bxyiJ9d9NsItDFYD0TBy9tkqCe5W32W6fyPCJB86Df/XjflbCKAKVYHOSgDDPMt1yIlXNCL/326arbhOeld4eSDYm3P1jBKMijWTSAujaXN3yXqDRyCkjvhgmmAV3CR6Zga5/5mZQHrRZ2MfgGGUG0HxSTanJ7NAoGBAOhZBGtFsBdtEawvCh4Z8NaMC2nU+Ru9hEsZSy6rZQrPvja9aBUk5QUdh04TYtu8PzQ1EghZy71rtwDAvxXWJ1mWcZn0kD06tZKudmZpMVXCp3SFah6DDUCFSmQ2U60yh6XOzpS2+Z97Ngi02UFph8sSQA6Dl/lmaf4bfQHCYc5Z", + }, + } + + for i, example := range invalidExamples { + err := example.Validate() + require.Error(t, err, "Invalid example %d was regarded as valid", i) + } + + validExamples := []*SAMLConfiguration{ + { + Enabled: false, + }, + { + // RSA 2048 + Enabled: true, + PrivateKey: "MIIEowIBAAKCAQEAsBuxTUWFrfy0qYXaqNSeVWcJOd6TQ4+4b/3N4p/58r1d/kMU+K+BGR+tF0GKHGYngTF6puvNDff2wgW3dp3LUSMjxOhC3sK0uL90vd+IR6v1EDDGLyQNo6EjP/x5Gp/PcL2s6hZb8iLBEq4FksPnEhWqf9Nsmgf1YPJV4AvaaWe3oBFo9zJobSs3etTVitc3qEH2DpgYFtrCKhMWv5qoZtZTyZRE3LU3rvInDgYw6HDGF1G4y4Fvah6VpRmTdyMR81r1tCLmGvk61QJp7i4HteazQ6Raqh2EZ1sH/UfEp8mrwYRaRdgLDQ/Q6/YlO8NTQwzp6YwwAybhMBnOrABLCQIDAQABAoIBADqobq0DPByQsIhKmmNjtn1RvYP1++0kANXknuAeUv2kT5tyMpkGtCRvJZM6dEszR3NDzMuufPVrI1jK2Kn8sw0KfE6I4kUaa2Gh+7uGqfjdcNn8tPZctuJKuNgGOzxAALNXqjGqUuPa6Z5UMm0JLX0blFfRTzoa7oNlFG9040H6CRjJQQGfYyPS8xeo+RUR009sK/222E5jz6ThIiCrOU/ZGm5Ws9y3AAIASqJd9QPy7qxKoFZ1qKZ/cDaf1txCKq9VBXH6ypZoU1dQibhyLCIJ3tYapBtV4p8V12oHhITXb6Vbo1P9bQSVz+2rQ0nJkjdXX/N4aHE01ecbu8MpMxUCgYEA5P4ZCAdpkTaOSJi7GyL4AcZ5MN26eifFnRO/tbmw07f6vi//vdqzC9T7kxmZ8e1OvhX5OMGNb3nsXm78WgS2EVLTkaTInG6XhlOeYj9BHAQZDBr7rcAxrVQxVgaGDiZpYun++kXw+39iq3gxuYuC9mM0AQze3SjTRIM9WWXJSqMCgYEAxODfXcWMk2P/WfjE3u+8fhjc3cvqyWSyThEZC9YzpN59dL73SE7BRkMDyZO19fFvVO9mKsRfsTio0ceC5XQOO6hUxAm4gAEvMpeapQgXTxIxF5FAQ0vGmBMxT+xg7lX8HTTJX/UCttKo3BdIJQeTf8bKVzJCoLFh8Rcv5qI6umMCgYAEuj44DTcfuVmcpBKQz9sA5mEQIjO8W9/Xi1XU4Z2F8XFqxcDo4X/6yY3cDpZACV8ry3ZWtqA94e2AUZhCH4DGwMf/ZMCDgkD8k/NcIeQtOORvfIsfni0oX+mY1g+kcSSR1zTdY95CwvF9isC0DO5KOegT8XkUZchezLrSgqhyMwKBgQCvS0mWRH6V/UMu6MDhfrNl0t1U3mt+RZo8yBx03ZO+CBvMBvxF9VlBJgoJQOuSwBVQmpdtHMvXD4vAvNNfWaYSmB5hLgaIcoWDlliq+DlIvfnX8gw13xJD9VLCxsTHcOe5WXazaYOxJIAU9uXVkplR+73NRYLtcQKzluGfiHKh4QKBgFpPtOqcAbkMsV+1qPYvvvX7E4+l52Odb4tbxGBYV8tzCqMRETqMPVxFWwsj+EQ8lyAu15rCRH7DKHVK5zL6JvIZEjt0tptKqSL2o3ovS6y3DmD6t+YpvjKME7a+vunOoJWe9pWl3wZmodfyZMpAdDLvDGhPR7Jlhun41tbMMaQF", + }, + { + // RSA 3072 + Enabled: true, + PrivateKey: 
"MIIG4wIBAAKCAYEApYkvDaXJEDsELSVosc0sKFnoPeJai8sOu8di5ffGVJRr7mJi+VQjM0d2KeOIllVk2IV58M33Jz2Rx61NYPLu0N9fZqPwbgYn+FNz1L1xgslUL6gyaQnCEKtH5mRqPEBOPvAygq/fZ46eBMs3GSS6NWp/XF/iPaFc1mBDAZFvXev4XV7O6iuqz5mx3rQbkIhMjQxP+IOYWMS4TqueLJWgFUbij0FepJfOE+AlmfBa7xIOyE+g5t3vRB8XwzxRPsljlfgZXstxO1r1NS3DPiUj3kGYy7em5Yb+icIA6xzy0MiwU5RcBSwtVc+M/Yk2tMY6a9z1UX2M5Zr/ih3w0CbW6KDYplqgwwDZv2f+ynIqldn7SjVo3V6fWFu+KtRkofWWkTGjaU2DTpxrxUJEnEo6zXfBSejAjGGAJyKjX74uATlOu/LQEjd5umQpWYvtvP1UkbjHYgITtoTytb3uU7Q7W/YdtNUcaE377QHZF+E+XTCCCw00bCvpDciW+w0JSkRfAgMBAAECggGAR0jCKIBiC0k+zSo04YxXHbFJ34xgLZ7t41NDdYCzuayIpglcUb43wlddvUAsi4COgudH0bkAW7eZ1YD9t2gmC3CFpq+mU9r2z2swkEZcYVPNmxA1VSJMnd0Eg2Ruky+mAlhxh/GwpOm3hpz0RzGXtnT8D42C4cNhNTgS4tP8P1fkhmDTfef8EJZBEIRC8oSfYoYQ0hXpPyDHtakV3mE4pLD303T1CrAMoGaACsCEiDsgfoY75e9gn9c75mlNG1qhhJYxD3Sv1o9lQd3Q1A71sga/E+yIlUcPP4fDaA8DdeH+FHwL9xgQPd18gsrbPdbsg8JMLmjblaz8BB1MvJMwj+b3Ey2idD8CVIq5Ql97TebyMxZp3ZYjLq/R2ay+MpE9Vjgih096Hg+kCPMPi3Q9AmVJX8kN8+2zm2EeDoI/YnJFzmBcmaOuSBEGYdrRk5RCYfZMa1jvpoNUGbWzoX4gRfC7Gr+alaCWa9ot2c+ChWZQlpbKaMYMLU/VEd7gsf/BAoHBANJsSdIxrTUWpdm27SJlq5ylKVokS5ftrxk2gM8M16jLYE7SwPrhdxphWGH8/TMz6/Jyd+CuSfwAzyy325szlFlZVpxv8qu1vWROBaaaq1Tg8cqYC2s+hUTJLevcmiBHFu+7tiYNmMqkNIfj9/FN1zvfPVwqurtB5WXGjI4qhf5SyJgtj1GiM/s9Ae86LiRZhovcEEwf0LddGpMrUEDrWOV9D95sOMA00rsJXOfOg78Ms7Nq/h9w6cnD5x4jUJTMzwKBwQDJY/TMNVa1V8ci+pOMB6iTyi3azVC6ZiCXinCQS0oLebY1GmyWLv9A+/+9Wg/h4p4OdlZSA2/9f6+6njAcxI1wfzHVC3vgF7EDs9YUeAmXWBA171uPHbfimTd21utLkcyJ/WdO4OmKP7ZIK8UWyXE98N5NQV9NRX0sm6CJemwChcoJ8/7lsuYa4nJVUXtAkAMoj7e0nOoWn1IzyolmIXSTrBPiLWh68172tr3ciR6uGN3Yba6szkFTeaBDfNQvk3ECgcEAy07XkKBwwv+L5SxKOFbVlfc6Wh8Bbty2tnyjvemhoTRHbEFTNdOMeU+ezqZamhNLoKgazVp4n2TEx2cpZu5SInYgKew8Is3pHLYJ3axJaCwjUmTPe6Ifr5NVrDMsM42cSqsqVeADRZ+cJcQMtvhHwlByf8/FNdJ4a3qIKYBKkKy5pdc3R1+aK+AJM3QaSwK47f8FPBftWI07dQB/fQonjSvlnjkgKA2hohdszYgKYRhLtEnnGMfHCywd7U+ftvWfAoHAcxfq+SGqkiy+I+FsnWRrFTtAhYE9F6nyCmkV94Dvqis+1I5rbFEjk7Hw7/geh4uJpN5AatKIGCn29gIdoPM7mgU3J3hOrT0c7u7B9CS95n5vlUNb4iirxJanugUNp7yFVn85oTyse1P6CrjpBCLP0wRrJ1+q5XBHH005rBgIzlBDrPiCvidFlivAB75vX/BtvaqU5GWg6pjW0752U6XfB94Z5vLoeQvJQ9ogG39Jx1lyv5O/dgbSErC5xJf8c8whAoHAYdxLfZcDN2IEUgg/czE9jDSJ2EwOCE9HpCntthHAvceK3+PFfpCKwOLynqF8urhdeM510QJK2ETLvzpgMBgSh/klxeBYv8BCL8BuPwyPciAFmPE1Stx7C1+JBF2fayYkCSK9w85INLAJYKTDk9gE8O6l0bXA8tuq3F0tRTwMBcyEpMOehKFamoPcU6cnNa2HC+MyTOfXSBeNZ2VciFYf5rh3YrwoUYbQJtDXxFvoX0Ba+zyneNG0j3epXZuR2lyK", + }, + { + // RSA 4096 + Enabled: true, + PrivateKey: 
"MIIJKgIBAAKCAgEA2cNnNX4Be3jOKTr7lxIWxWfFKtwFqbWs9CZS7gDNXUtBlGuV1+FswPvSRKWEmwsBQikBfScowk4hL/JFgN8V25PijOk7eTPmw3tHuUhoil7GkJCMKhtrYwGbvINk1pK5mfI+V8GR3l52S779fg8nwktOtr99sLgfxUdxwxFY5hE5lo5P19QPClAA89SjQ3c/FlXy8R56/qf4u+Fuvd7Ecq7nQGeovsiSpBxY2gn4KL2LdkkyZmEQVgXzXjDGOOhF7M6eKim5MCsUqgHjCCkK7Gw9HNbd4oHNE5ucWRYjG1IpEYbYmep/9+wXgwQorYFKUT0NXrUv5H3VLQpsDyWDRZJ+wXGbwV2bRh2Z5bbAJVTxF8NaO8XujVZLIe+UJ8kUWj+n3hxwil9UU9yExR6M9TZBfHTKOVWcn1CquT85ppI0dtvlu3ToBwjjcd1wWLK8rLhmEwafC142bSL2kXLc6p7YrhTBN7PBPodQ2lLMg8xbw4cNspsMAPAPfrisqEYUGAs/EUScgcsSfmyzKNcdZlUx6UkMhz2F8sKPi4I4oIugxQiCa7LuSjmfrM6msIkrV+sj06zUYmAZzN+cf7rRlGFLNt1cKqqukjhbo9RL54XZQssT5GkHuVT6neyQBJX9EwtmZtXBTI78WTUabQhBcEBbxWbn5VodxDPXmfAiumsCAwEAAQKCAgEAnU1ux5BPJ877NYNa7CTv+AdewPgQyyfmWLM6YpyHvKW5KKqSolA/jCQcHuRlps3LSexvG+Xmpn1jscvTcyUzF9t64ok0Ifhg8MKj6+6nPZT64MDZzyzhZLJrukA73lg85Dy91gyI/1XDJDJB0QbHlK1rnc0z0S0gHhTe06c7TW4R6HTCrkiL2Moz9e6bRQfltY++n3iCJmRV4/oTUeqSg7leaQK4PaCLdSrY8CAVd/B7xqVXV+czssA3rcmT1tXKdSZH0HM1R9tG4Qvd4S4sqt4BQ0zfGVjkOA7HYP8BuyGdcwCyhHSFniSYU1b0v2jOs2Jjvw8pGmffTtrhdguGB60rMocKyfXvRxjJmIXZae6W8ZCwz76rKr8igXZUXvK3LqhGfm5fDpvWQlX8ugnwWOmowJqToS/fVKwhjFjsPONRbRZh7MTebRjx9ErpQycTm0SiUrUA/WE8Na1JeelTjxThCuy1VjIOtYVk4eYGP6REQV+nYGGuD7ruR+dpD4UR3/2DsPLik8X+YUFMjGCr+LjzybDj8Ux+a/u/eKD3rIe45PooJzGR/s+RCcwtAIue29+C+2uj3lAypEIqRGd2k0RgEw8Cj43Omc3Pyf+M3IbKfpE82OGSPp/rgHIfJSwGuOWH09yxCjyqY9H/wtxea6qOpeuk/g4ipaTp/QvZikkCggEBAPeowAf5hz16Oreb4D1c9MoHdofO/Kv8D8Kig9YhwDGck8uG4TbdBMSre5mOGSykZBA566/YHf49Pdi5Eq3L5oEdNKJrxn35WlMHY2pdOCrhWfqe+Z3Fg6qlhQTFU0blFAwy6NUixHP7xsLyAdpjkSxdsQzOaHUMII8w0gD+/AqSq3c/sC9AF+CeiZQV0P53eseNVfxfv8f1aDH7JcywG4P6Xe9pdHoNW93u2j2HQcrLidOtsT5s8iXj2YO3d4YZg/I20dViC7+DrG1ep+rfiuYY5VS1jKVqTknzKHlP7OHOaYJhDPAffnNFBWj4Th11NKxigpx3ogXO9jVyCGXWwD0CggEBAOEY5hvGEufmWx821Oln3OoaUv4MBSa0RMuuBZyx7mx18gfidjB9exC6sFduTwdFjnM8PUDBnOo2ZM7ThcCbTJ4Qi7LB5gDAXUbJqJk7o+lKrfXcpYdksoXWHmWAT7RE1v9nbXle1KHKIaaga/I8hVtSfeTizb8y+dDP3T3H8tVByvneAE0LnDVmr1VhFppKnzWl5vTY2Y+6XGIWmrCuWS1+zf+dx32zJ2ZOfT1Wwk20igC79RzH0sDHSv7DNyUn9u/9LtjIIrDtWch9+5Xkq0uZQAqM0Jw/QUYqarJSNNVhREmwWk+B6sJaQUN26YyTHiOpfFu1RUwHyyg58L8yJ8cCggEBALqSqnhXh4bM+kcwavJPgSpiDO2rBbcbIVRj0iYTLxMw/jap2ijWwKzY8zhvUI/NGIUQ3XmPuqi5wknuwx+jKHfEZM6nmtV0cJN0UXTj3ViQhJTGBw7Qqax5HYjGj0Itebjm8Xj/xDgMSWS7pKG9uLRPsP4Q0ai8BhtZkBun/ICKlho0JKq0Akj5pnOlK9lIcXq8AzcpevVM774XkhZt5Yy7pOCj9VetkLPVKRyJNQtt4ttRUuHQeWwKBuev459mwXxLyDCUuH0C2Xdbg+zxk1ZdEweJ7fb/6xLS2H7rs205b0sFihWr5Ds6mCTISzDuB0yGuhbeGXV+wQTqb2EpM5ECggEBAMBFsGiQ7J1BWxxyjbNBkKY3DiUKx2ukGA+S+iA6rFng9XherG4HARPtI6vLAZ5If8FW90tVFl/JTpqMe3dmMC/kGi/7CCgkKIjKwEUDeKNRsv6MFqhsD0Ha/+Pbkjl9g9ht1EkUA7SfH9dguFQV9iNndzoHsY9cT59ZrrWTEY2vwV1lkAQ/opLKv4HCiLgKfawppfoHMO9gVIFEpaW9h1chNXzenQR1/3WYHcpDTX1qdWbjJiALX65jjV/ICFaoqHmeXmG1skxGsaZcVoZW6SqOIPHiDl8oeO0iVjkzlwWdK+N1y+6WHp0c0xp5fE0jbV8w6pS7ZhHnplUaCNaIVQkCggEAUcQ0VhA1FlWT/mcbTp3v3ojCkZH8fkbNQoL5nsxb+72VDGB0YYZBXyBCApY2E/3jfAH8nG0ALAVS1s983USx17Z2+Z+Yg13Gt1dglV0JC2vMS6j83G0UxsKdcFyafbgJl+hrvBAuOoqPLd5r5F6VnDZeDDsJ3Y6ZTmcsbf8EZkUSxT80oKBLwXi06dfnEz7nYUxvqk54QG3xN1VJAQoKaJ9sH9pbAPdA0GxRx8DIWBd3UhMFJbdIplfGlkk9kf+E1k6Z2SaRB8QQHpvdgsdQ6YXPV+0ejhiGytX9DMSmjZe3dC4C7ZdaCL+kSxdFRgIo2KAcJVdpsqbw/hclfNY7cQ==", + }, + } + + for i, example := range validExamples { + err := example.Validate() + require.NoError(t, err, "Valid example %d was regarded as invalid", i) + } +} + +func TestSAMLConfigurationPopulateFields(t *tst.T) { + c := &SAMLConfiguration{ + Enabled: true, + PrivateKey: 
"MIIEowIBAAKCAQEAt7dS8iM5MsQ+1mVkNpoaUnL8BCdxSrSx8jsSnvqN/GIJ4ipqbdrTgLpFVklVTqfaa5CykGVEV577l6AWkpkm2p7SvSkCQglmyAMMjY9glmztytAnfBpm+cQ6ZVTHC4XKlUG1aJigEuXPcZUU3FiBHWEuV2huYy2bLOtIY1v9N0i2v61QCdG+SM/Yb5t86KzApRl7VyHqquge6vvRuchfF0msv/2LW32hwxg3Gt4zkAF0SJqCCcfAPZ9pQwmbdUhoX16dRFU98nyIvuR8LH/wONZe/YyywFFHDEwkFa4XEzjCEm+AD+xvK7eEu55w21xB8JKMLEBy8uRuI3bIEG4pawIDAQABAoIBADw4IT4xgYw8e4R3U7P6K2qfOjB6ZU5hkHqgFmh6JJR35ll2IdDEi9OEOzofa5EOwC/GDGH8b7xw5nM7DGsdPHko2lca3BydTE1/glvchYKJTiDOvkKVvO9d/O4+Lch/IHpwQXB5pu7K2YaXoXDgqeHhevk3yAdGabj9norDGmtGIeU/x1hialKbw6L080CdbxpjeAsM/w+G/VtwvyOKYFBYxBflRW+sS8UeclVqKRAvaXKd1JGleWzH3hFZyFI54x5LyyjPI1JyVXRjNbf8xcS6eRaN849grL1+wBxEs/lQFn4JLhAcNi912iJ3lhxvkNleXZw7B7JAM8x4wUbK7zECgYEA6SYmu3YH8XuLUfT8MMCp+ETjPkNMOJGQmTXOkW6zuXP3J8iCPIxtuz09cGIro+yJU23yPUzOVCDZMmnMWBmkoTKAFoFL9TX0Eyqn/t1MD77i3NdkMp16yI5fwOO6yX1bZgLiG00W2E5/IGgNfTtEafU/mre95JBnTgxS3sAvz8UCgYEAybjfBVt+1X0vSVAGKYHI9wtzoSx3dIGE8G5LIchPTdNDZ0ke0QCRffhyCGKy6bPos0P2z5nLgWSePBPZQowpwZiQVXdWE05ID641E2zGULdYL1yVHDt6tVTpSzTAy89BiS1G8HvgpQyaBTmvmF11Fyd/YbrDxEIHN+qQdDkM928CgYEA4lJ4ksz21QF6sqpADQtZc3lbplspqFgVp8RFq4Nsz3+00lefpSskcff2phuGBXBdtjEqTzs5pwzkCj4NcRAjcZ9WG4KTu4sOTXTA83TamwZPrtUfnMqmH/2lEdd+wI0BpjryRlJE9ODuIwUe4wwfU0QQ5B2tJizPO0JXR4gEYYkCgYBzqidm4QGm1DLq7JG79wkObmiMv/x2t1VMr1ExO7QNQdfiP1EGMjc6bdyk5kMEMf5527yHaP4BYXpBpHfs6oV+1kXcW6LlSvuS0iboznQgECDmd0WgfJJtqxRh5QuvUVWYnHeSqNU0jjc6S8tdqCjdb+5gUUCzJdERxNOzcIr4zQKBgAqcBQwlWy0PdlZ06JhJUYlwX1pOU8mWPz9LIF0wrSm9LEtAl37zZJaD3uscvk/fCixAGHOktkDGVO7aUYIAlX9iD49huGkeRTn9tz7Wanw6am04Xj0y7H1oPPV7k5nJ4s9AOWq/gkZEhrRIis2anAczsx1YHSjq/M05+AbuRzvs", + } + + err := c.PopulateFields("https://projectref.supabase.co") + require.NoError(t, err) + + require.NotNil(t, c.RSAPrivateKey) + require.NotNil(t, c.RSAPublicKey) + require.NotNil(t, c.Certificate) +} + +func TestSAMLConfigurationDeterministicCertificate(t *tst.T) { + a := &SAMLConfiguration{ + Enabled: true, + PrivateKey: "MIIEowIBAAKCAQEAt7dS8iM5MsQ+1mVkNpoaUnL8BCdxSrSx8jsSnvqN/GIJ4ipqbdrTgLpFVklVTqfaa5CykGVEV577l6AWkpkm2p7SvSkCQglmyAMMjY9glmztytAnfBpm+cQ6ZVTHC4XKlUG1aJigEuXPcZUU3FiBHWEuV2huYy2bLOtIY1v9N0i2v61QCdG+SM/Yb5t86KzApRl7VyHqquge6vvRuchfF0msv/2LW32hwxg3Gt4zkAF0SJqCCcfAPZ9pQwmbdUhoX16dRFU98nyIvuR8LH/wONZe/YyywFFHDEwkFa4XEzjCEm+AD+xvK7eEu55w21xB8JKMLEBy8uRuI3bIEG4pawIDAQABAoIBADw4IT4xgYw8e4R3U7P6K2qfOjB6ZU5hkHqgFmh6JJR35ll2IdDEi9OEOzofa5EOwC/GDGH8b7xw5nM7DGsdPHko2lca3BydTE1/glvchYKJTiDOvkKVvO9d/O4+Lch/IHpwQXB5pu7K2YaXoXDgqeHhevk3yAdGabj9norDGmtGIeU/x1hialKbw6L080CdbxpjeAsM/w+G/VtwvyOKYFBYxBflRW+sS8UeclVqKRAvaXKd1JGleWzH3hFZyFI54x5LyyjPI1JyVXRjNbf8xcS6eRaN849grL1+wBxEs/lQFn4JLhAcNi912iJ3lhxvkNleXZw7B7JAM8x4wUbK7zECgYEA6SYmu3YH8XuLUfT8MMCp+ETjPkNMOJGQmTXOkW6zuXP3J8iCPIxtuz09cGIro+yJU23yPUzOVCDZMmnMWBmkoTKAFoFL9TX0Eyqn/t1MD77i3NdkMp16yI5fwOO6yX1bZgLiG00W2E5/IGgNfTtEafU/mre95JBnTgxS3sAvz8UCgYEAybjfBVt+1X0vSVAGKYHI9wtzoSx3dIGE8G5LIchPTdNDZ0ke0QCRffhyCGKy6bPos0P2z5nLgWSePBPZQowpwZiQVXdWE05ID641E2zGULdYL1yVHDt6tVTpSzTAy89BiS1G8HvgpQyaBTmvmF11Fyd/YbrDxEIHN+qQdDkM928CgYEA4lJ4ksz21QF6sqpADQtZc3lbplspqFgVp8RFq4Nsz3+00lefpSskcff2phuGBXBdtjEqTzs5pwzkCj4NcRAjcZ9WG4KTu4sOTXTA83TamwZPrtUfnMqmH/2lEdd+wI0BpjryRlJE9ODuIwUe4wwfU0QQ5B2tJizPO0JXR4gEYYkCgYBzqidm4QGm1DLq7JG79wkObmiMv/x2t1VMr1ExO7QNQdfiP1EGMjc6bdyk5kMEMf5527yHaP4BYXpBpHfs6oV+1kXcW6LlSvuS0iboznQgECDmd0WgfJJtqxRh5QuvUVWYnHeSqNU0jjc6S8tdqCjdb+5gUUCzJdERxNOzcIr4zQKBgAqcBQwlWy0PdlZ06JhJUYlwX1pOU8mWPz9LIF0wrSm9LEtAl37zZJaD3uscvk/fCixAGHOktkDGVO7aUYIAlX9iD49huGkeRTn9tz7Wanw6am04Xj0y7H1oPPV7k5nJ4s9AOWq/gkZEhrRIis2anAczsx1YHSjq/M05+AbuRzvs", + } + + b := &SAMLConfiguration{ + 
Enabled: a.Enabled, + PrivateKey: a.PrivateKey, + } + + err := a.PopulateFields("https://projectref.supabase.co") + require.NoError(t, err) + + err = b.PopulateFields("https://projectref.supabase.co") + require.NoError(t, err) + + require.Equal(t, a.Certificate.Raw, b.Certificate.Raw, "Certificate generation should be deterministic") +} diff --git a/auth_v2.169.0/internal/conf/tracing.go b/auth_v2.169.0/internal/conf/tracing.go new file mode 100644 index 0000000..9a1d9be --- /dev/null +++ b/auth_v2.169.0/internal/conf/tracing.go @@ -0,0 +1,33 @@ +package conf + +type TracingExporter = string + +const ( + OpenTelemetryTracing TracingExporter = "opentelemetry" +) + +type TracingConfig struct { + Enabled bool + Exporter TracingExporter `default:"opentelemetry"` + + // ExporterProtocol is the OTEL_EXPORTER_OTLP_PROTOCOL env variable, + // only available when exporter is opentelemetry. See: + // https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md + ExporterProtocol string `default:"http/protobuf" envconfig:"OTEL_EXPORTER_OTLP_PROTOCOL"` + + // Host is the host of the OpenTracing collector. + Host string + + // Port is the port of the OpenTracing collector. + Port string + + // ServiceName is the service name to use with OpenTracing. + ServiceName string `default:"gotrue" split_words:"true"` + + // Tags are the tags to associate with OpenTracing. + Tags map[string]string +} + +func (tc *TracingConfig) Validate() error { + return nil +} diff --git a/auth_v2.169.0/internal/crypto/crypto.go b/auth_v2.169.0/internal/crypto/crypto.go new file mode 100644 index 0000000..6fc2b71 --- /dev/null +++ b/auth_v2.169.0/internal/crypto/crypto.go @@ -0,0 +1,159 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "math" + "math/big" + "strconv" + "strings" + + "golang.org/x/crypto/hkdf" +) + +// SecureToken creates a new random token +func SecureToken() string { + b := make([]byte, 16) + must(io.ReadFull(rand.Reader, b)) + + return base64.RawURLEncoding.EncodeToString(b) +} + +// GenerateOtp generates a random n digit otp +func GenerateOtp(digits int) string { + upper := math.Pow10(digits) + val := must(rand.Int(rand.Reader, big.NewInt(int64(upper)))) + + // adds a variable zero-padding to the left to ensure otp is uniformly random + expr := "%0" + strconv.Itoa(digits) + "v" + otp := fmt.Sprintf(expr, val.String()) + + return otp +} +func GenerateTokenHash(emailOrPhone, otp string) string { + return fmt.Sprintf("%x", sha256.Sum224([]byte(emailOrPhone+otp))) +} + +// Generated a random secure integer from [0, max[ +func secureRandomInt(max int) int { + randomInt := must(rand.Int(rand.Reader, big.NewInt(int64(max)))) + return int(randomInt.Int64()) +} + +type EncryptedString struct { + KeyID string `json:"key_id"` + Algorithm string `json:"alg"` + Data []byte `json:"data"` + Nonce []byte `json:"nonce,omitempty"` +} + +func (es *EncryptedString) IsValid() bool { + return es.KeyID != "" && len(es.Data) > 0 && len(es.Nonce) > 0 && es.Algorithm == "aes-gcm-hkdf" +} + +// ShouldReEncrypt tells you if the value encrypted needs to be encrypted again with a newer key. 
+func (es *EncryptedString) ShouldReEncrypt(encryptionKeyID string) bool { + return es.KeyID != encryptionKeyID +} + +func (es *EncryptedString) Decrypt(id string, decryptionKeys map[string]string) ([]byte, error) { + decryptionKey := decryptionKeys[es.KeyID] + + if decryptionKey == "" { + return nil, fmt.Errorf("crypto: decryption key with name %q does not exist", es.KeyID) + } + + key, err := deriveSymmetricKey(id, es.KeyID, decryptionKey) + if err != nil { + return nil, err + } + + block := must(aes.NewCipher(key)) + cipher := must(cipher.NewGCM(block)) + + decrypted, err := cipher.Open(nil, es.Nonce, es.Data, nil) // #nosec G407 + if err != nil { + return nil, err + } + + return decrypted, nil +} + +func ParseEncryptedString(str string) *EncryptedString { + if !strings.HasPrefix(str, "{") { + return nil + } + + var es EncryptedString + + if err := json.Unmarshal([]byte(str), &es); err != nil { + return nil + } + + if !es.IsValid() { + return nil + } + + return &es +} + +func (es *EncryptedString) String() string { + out := must(json.Marshal(es)) + + return string(out) +} + +func deriveSymmetricKey(id, keyID, keyBase64URL string) ([]byte, error) { + hkdfKey, err := base64.RawURLEncoding.DecodeString(keyBase64URL) + if err != nil { + return nil, err + } + + if len(hkdfKey) != 256/8 { + return nil, fmt.Errorf("crypto: key with ID %q is not 256 bits", keyID) + } + + // Since we use AES-GCM here, the same symmetric key *must not be used + // more than* 2^32 times. But, that's not that much. Suppose a system + // with 100 million users, then a user can only change their password + // 42 times. To prevent this, the actual symmetric key is derived by + // using HKDF using the encryption key and the "ID" of the object + // containing the encryption string. Ideally this ID is a UUID. This + // has the added benefit that the encrypted string is bound to that + // specific object, and can't accidentally be "moved" to other objects + // without changing their ID to the original one. 
+ + keyReader := hkdf.New(sha256.New, hkdfKey, nil, []byte(id)) + key := make([]byte, 256/8) + + must(io.ReadFull(keyReader, key)) + + return key, nil +} + +func NewEncryptedString(id string, data []byte, keyID string, keyBase64URL string) (*EncryptedString, error) { + key, err := deriveSymmetricKey(id, keyID, keyBase64URL) + if err != nil { + return nil, err + } + + block := must(aes.NewCipher(key)) + cipher := must(cipher.NewGCM(block)) + + es := EncryptedString{ + KeyID: keyID, + Algorithm: "aes-gcm-hkdf", + Nonce: make([]byte, 12), + } + + must(io.ReadFull(rand.Reader, es.Nonce)) + es.Data = cipher.Seal(nil, es.Nonce, data, nil) // #nosec G407 + + return &es, nil +} diff --git a/auth_v2.169.0/internal/crypto/crypto_test.go b/auth_v2.169.0/internal/crypto/crypto_test.go new file mode 100644 index 0000000..f1c8e67 --- /dev/null +++ b/auth_v2.169.0/internal/crypto/crypto_test.go @@ -0,0 +1,108 @@ +package crypto + +import ( + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" +) + +func TestEncryptedStringPositive(t *testing.T) { + id := uuid.Must(uuid.NewV4()).String() + + es, err := NewEncryptedString(id, []byte("data"), "key-id", "pwFoiPyybQMqNmYVN0gUnpbfpGQV2sDv9vp0ZAxi_Y4") + assert.NoError(t, err) + + assert.Equal(t, es.KeyID, "key-id") + assert.Equal(t, es.Algorithm, "aes-gcm-hkdf") + assert.Len(t, es.Data, 20) + assert.Len(t, es.Nonce, 12) + + dec := ParseEncryptedString(es.String()) + + assert.NotNil(t, dec) + assert.Equal(t, dec.Algorithm, "aes-gcm-hkdf") + assert.Len(t, dec.Data, 20) + assert.Len(t, dec.Nonce, 12) + + decrypted, err := dec.Decrypt(id, map[string]string{ + "key-id": "pwFoiPyybQMqNmYVN0gUnpbfpGQV2sDv9vp0ZAxi_Y4", + }) + + assert.NoError(t, err) + assert.Equal(t, []byte("data"), decrypted) +} + +func TestParseEncryptedStringNegative(t *testing.T) { + negativeExamples := []string{ + "not-an-encrypted-string", + // not json + "{{", + // not parsable json + `{"key_id":1}`, + `{"alg":1}`, + `{"data":"!!!"}`, + `{"nonce":"!!!"}`, + // not valid + `{}`, + `{"key_id":"key_id"}`, + `{"key_id":"key_id","alg":"different","data":"AQAB=","nonce":"AQAB="}`, + } + + for _, example := range negativeExamples { + assert.Nil(t, ParseEncryptedString(example)) + } +} + +func TestEncryptedStringDecryptNegative(t *testing.T) { + id := uuid.Must(uuid.NewV4()).String() + + // short key + _, err := NewEncryptedString(id, []byte("data"), "key-id", "short_key") + assert.Error(t, err) + + // not base64 + _, err = NewEncryptedString(id, []byte("data"), "key-id", "!!!") + assert.Error(t, err) + + es, err := NewEncryptedString(id, []byte("data"), "key-id", "pwFoiPyybQMqNmYVN0gUnpbfpGQV2sDv9vp0ZAxi_Y4") + assert.NoError(t, err) + + dec := ParseEncryptedString(es.String()) + assert.NotNil(t, dec) + + _, err = dec.Decrypt(id, map[string]string{ + // empty map + }) + assert.Error(t, err) + + // short key + _, err = dec.Decrypt(id, map[string]string{ + "key-id": "AQAB", + }) + assert.Error(t, err) + + // key not base64 + _, err = dec.Decrypt(id, map[string]string{ + "key-id": "!!!", + }) + assert.Error(t, err) + + // bad key + _, err = dec.Decrypt(id, map[string]string{ + "key-id": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", + }) + assert.Error(t, err) + + // bad tag for AEAD failure + dec.Data[len(dec.Data)-1] += 1 + + _, err = dec.Decrypt(id, map[string]string{ + "key-id": "pwFoiPyybQMqNmYVN0gUnpbfpGQV2sDv9vp0ZAxi_Y4", + }) + assert.Error(t, err) +} + +func TestSecureToken(t *testing.T) { + assert.Equal(t, len(SecureToken()), 22) +} diff --git 
a/auth_v2.169.0/internal/crypto/password.go b/auth_v2.169.0/internal/crypto/password.go new file mode 100644 index 0000000..7cf4607 --- /dev/null +++ b/auth_v2.169.0/internal/crypto/password.go @@ -0,0 +1,418 @@ +package crypto + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/subtle" + "encoding/base64" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/supabase/auth/internal/observability" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + + "golang.org/x/crypto/argon2" + "golang.org/x/crypto/bcrypt" + "golang.org/x/crypto/scrypt" +) + +type HashCost = int + +const ( + // DefaultHashCost represents the default + // hashing cost for any hashing algorithm. + DefaultHashCost HashCost = iota + + // QuickHashCosts represents the quickest + // hashing cost for any hashing algorithm, + // useful for tests only. + QuickHashCost HashCost = iota + + Argon2Prefix = "$argon2" + FirebaseScryptPrefix = "$fbscrypt" + FirebaseScryptKeyLen = 32 // Firebase uses AES-256 which requires 32 byte keys: https://pkg.go.dev/golang.org/x/crypto/scrypt#Key +) + +// PasswordHashCost is the current password hashing cost +// for all new hashes generated with +// GenerateHashFromPassword. +var PasswordHashCost = DefaultHashCost + +var ( + generateFromPasswordSubmittedCounter = observability.ObtainMetricCounter("gotrue_generate_from_password_submitted", "Number of submitted GenerateFromPassword hashing attempts") + generateFromPasswordCompletedCounter = observability.ObtainMetricCounter("gotrue_generate_from_password_completed", "Number of completed GenerateFromPassword hashing attempts") +) + +var ( + compareHashAndPasswordSubmittedCounter = observability.ObtainMetricCounter("gotrue_compare_hash_and_password_submitted", "Number of submitted CompareHashAndPassword hashing attempts") + compareHashAndPasswordCompletedCounter = observability.ObtainMetricCounter("gotrue_compare_hash_and_password_completed", "Number of completed CompareHashAndPassword hashing attempts") +) + +var ErrArgon2MismatchedHashAndPassword = errors.New("crypto: argon2 hash and password mismatch") +var ErrScryptMismatchedHashAndPassword = errors.New("crypto: fbscrypt hash and password mismatch") + +// argon2HashRegexp https://github.com/P-H-C/phc-string-format/blob/master/phc-sf-spec.md#argon2-encoding +var argon2HashRegexp = regexp.MustCompile("^[$](?P<alg>argon2(d|i|id))[$]v=(?P<v>(16|19))[$]m=(?P<m>[0-9]+),t=(?P<t>[0-9]+),p=(?P<p>[0-9]+)(,keyid=(?P<keyid>[^,$]+))?(,data=(?P<data>[^$]+))?[$](?P<salt>[^$]*)[$](?P<hash>.*)$") +var fbscryptHashRegexp = regexp.MustCompile(`^\$fbscrypt\$v=(?P<v>[0-9]+),n=(?P<n>[0-9]+),r=(?P<r>[0-9]+),p=(?P<p>[0-9]+)(?:,ss=(?P<ss>[^,]+))?(?:,sk=(?P<sk>[^$]+))?\$(?P<salt>[^$]+)\$(?P<hash>.+)$`) + 
+type Argon2HashInput struct { + alg string + v string + memory uint64 + time uint64 + threads uint64 + keyid string + data string + salt []byte + rawHash []byte +} + +type FirebaseScryptHashInput struct { + v string + memory uint64 + rounds uint64 + threads uint64 + saltSeparator []byte + signerKey []byte + salt []byte + rawHash []byte +} + +// See: https://github.com/firebase/scrypt for implementation +func ParseFirebaseScryptHash(hash string) (*FirebaseScryptHashInput, error) { + submatch := fbscryptHashRegexp.FindStringSubmatchIndex(hash) + if submatch == nil { + return nil, errors.New("crypto: incorrect scrypt hash format") + } + + v := string(fbscryptHashRegexp.ExpandString(nil, "$v", hash, submatch)) + n := string(fbscryptHashRegexp.ExpandString(nil, "$n", hash, submatch)) + r := string(fbscryptHashRegexp.ExpandString(nil, "$r", hash, submatch)) + p := string(fbscryptHashRegexp.ExpandString(nil, "$p", hash, submatch)) + ss := string(fbscryptHashRegexp.ExpandString(nil, "$ss", hash, submatch)) + sk := string(fbscryptHashRegexp.ExpandString(nil, "$sk", hash, submatch)) + saltB64 := string(fbscryptHashRegexp.ExpandString(nil, "$salt", hash, submatch)) + hashB64 := string(fbscryptHashRegexp.ExpandString(nil, "$hash", hash, submatch)) + + if v != "1" { + return nil, fmt.Errorf("crypto: Firebase scrypt hash uses unsupported version %q only version 1 is supported", v) + } + memoryPower, err := strconv.ParseUint(n, 10, 32) + if err != nil { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid n parameter %q %w", n, err) + } + if memoryPower == 0 { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid n=0") + } + rounds, err := strconv.ParseUint(r, 10, 64) + if err != nil { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid r parameter %q: %w", r, err) + } + if rounds == 0 { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid r=0") + } + + threads, err := strconv.ParseUint(p, 10, 8) + if err != nil { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid p parameter %q %w", p, err) + } + if threads == 0 { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid p=0") + } + + rawHash, err := base64.StdEncoding.DecodeString(hashB64) + if err != nil { + return nil, fmt.Errorf("crypto: Firebase scrypt hash has invalid base64 in the hash section %w", err) + } + + salt, err := base64.StdEncoding.DecodeString(saltB64) + if err != nil { + return nil, fmt.Errorf("crypto: Firebase scrypt salt has invalid base64 in the hash section %w", err) + } + + var saltSeparator, signerKey []byte + if signerKey, err = base64.StdEncoding.DecodeString(sk); err != nil { + return nil, err + } + if saltSeparator, err = base64.StdEncoding.DecodeString(ss); err != nil { + return nil, err + } + + input := &FirebaseScryptHashInput{ + v: v, + memory: uint64(1) << memoryPower, + rounds: rounds, + threads: threads, + salt: salt, + rawHash: rawHash, + saltSeparator: saltSeparator, + signerKey: signerKey, + } + + return input, nil +} + +func ParseArgon2Hash(hash string) (*Argon2HashInput, error) { + submatch := argon2HashRegexp.FindStringSubmatchIndex(hash) + if submatch == nil { + return nil, errors.New("crypto: incorrect argon2 hash format") + } + + alg := string(argon2HashRegexp.ExpandString(nil, "$alg", hash, submatch)) + v := string(argon2HashRegexp.ExpandString(nil, "$v", hash, submatch)) + m := string(argon2HashRegexp.ExpandString(nil, "$m", hash, submatch)) + t := 
string(argon2HashRegexp.ExpandString(nil, "$t", hash, submatch)) + p := string(argon2HashRegexp.ExpandString(nil, "$p", hash, submatch)) + keyid := string(argon2HashRegexp.ExpandString(nil, "$keyid", hash, submatch)) + data := string(argon2HashRegexp.ExpandString(nil, "$data", hash, submatch)) + saltB64 := string(argon2HashRegexp.ExpandString(nil, "$salt", hash, submatch)) + hashB64 := string(argon2HashRegexp.ExpandString(nil, "$hash", hash, submatch)) + + if alg != "argon2i" && alg != "argon2id" { + return nil, fmt.Errorf("crypto: argon2 hash uses unsupported algorithm %q only argon2i and argon2id supported", alg) + } + + if v != "19" { + return nil, fmt.Errorf("crypto: argon2 hash uses unsupported version %q only %d is supported", v, argon2.Version) + } + + if data != "" { + return nil, fmt.Errorf("crypto: argon2 hashes with the data parameter not supported") + } + + if keyid != "" { + return nil, fmt.Errorf("crypto: argon2 hashes with the keyid parameter not supported") + } + + memory, err := strconv.ParseUint(m, 10, 32) + if err != nil { + return nil, fmt.Errorf("crypto: argon2 hash has invalid m parameter %q %w", m, err) + } + + time, err := strconv.ParseUint(t, 10, 32) + if err != nil { + return nil, fmt.Errorf("crypto: argon2 hash has invalid t parameter %q %w", t, err) + } + + threads, err := strconv.ParseUint(p, 10, 8) + if err != nil { + return nil, fmt.Errorf("crypto: argon2 hash has invalid p parameter %q %w", p, err) + } + + rawHash, err := base64.RawStdEncoding.DecodeString(hashB64) + if err != nil { + return nil, fmt.Errorf("crypto: argon2 hash has invalid base64 in the hash section %w", err) + } + if len(rawHash) == 0 { + return nil, errors.New("crypto: argon2 hash is empty") + } + + salt, err := base64.RawStdEncoding.DecodeString(saltB64) + if err != nil { + return nil, fmt.Errorf("crypto: argon2 hash has invalid base64 in the salt section %w", err) + } + if len(salt) == 0 { + return nil, errors.New("crypto: argon2 salt is empty") + } + + input := Argon2HashInput{ + alg: alg, + v: v, + memory: memory, + time: time, + threads: threads, + keyid: keyid, + data: data, + salt: salt, + rawHash: rawHash, + } + + return &input, nil +} + +func compareHashAndPasswordArgon2(ctx context.Context, hash, password string) error { + input, err := ParseArgon2Hash(hash) + if err != nil { + return err + } + + attributes := []attribute.KeyValue{ + attribute.String("alg", input.alg), + attribute.String("v", input.v), + attribute.Int64("m", int64(input.memory)), + attribute.Int64("t", int64(input.time)), + attribute.Int("p", int(input.threads)), + attribute.Int("len", len(input.rawHash)), + } // #nosec G115 + + var match bool + var derivedKey []byte + compareHashAndPasswordSubmittedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + defer func() { + attributes = append(attributes, attribute.Bool( + "match", + match, + )) + + compareHashAndPasswordCompletedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + }() + + switch input.alg { + case "argon2i": + derivedKey = argon2.Key([]byte(password), input.salt, uint32(input.time), uint32(input.memory), uint8(input.threads), uint32(len(input.rawHash))) // #nosec G115 + + case "argon2id": + derivedKey = argon2.IDKey([]byte(password), input.salt, uint32(input.time), uint32(input.memory), uint8(input.threads), uint32(len(input.rawHash))) // #nosec G115 + } + + match = subtle.ConstantTimeCompare(derivedKey, input.rawHash) == 1 + + if !match { + return ErrArgon2MismatchedHashAndPassword + } + + return nil +} + +func 
compareHashAndPasswordFirebaseScrypt(ctx context.Context, hash, password string) error { + input, err := ParseFirebaseScryptHash(hash) + if err != nil { + return err + } + + attributes := []attribute.KeyValue{ + attribute.String("v", input.v), + attribute.Int64("n", int64(input.memory)), + attribute.Int64("r", int64(input.rounds)), + attribute.Int("p", int(input.threads)), + attribute.Int("len", len(input.rawHash)), + } // #nosec G115 + + var match bool + compareHashAndPasswordSubmittedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + defer func() { + attributes = append(attributes, attribute.Bool("match", match)) + compareHashAndPasswordCompletedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + }() + + derivedKey := firebaseScrypt([]byte(password), input.salt, input.signerKey, input.saltSeparator, input.memory, input.rounds, input.threads) + + match = subtle.ConstantTimeCompare(derivedKey, input.rawHash) == 1 + if !match { + return ErrScryptMismatchedHashAndPassword + } + + return nil +} + +func firebaseScrypt(password, salt, signerKey, saltSeparator []byte, memCost, rounds, p uint64) []byte { + ck := must(scrypt.Key(password, append(salt, saltSeparator...), int(memCost), int(rounds), int(p), FirebaseScryptKeyLen)) // #nosec G115 + block := must(aes.NewCipher(ck)) + + cipherText := make([]byte, aes.BlockSize+len(signerKey)) + + // #nosec G407 -- Firebase scrypt requires deterministic IV for consistent results. See: JaakkoL/firebase-scrypt-python@master/firebasescrypt/firebasescrypt.py#L58 + stream := cipher.NewCTR(block, cipherText[:aes.BlockSize]) + stream.XORKeyStream(cipherText[aes.BlockSize:], signerKey) + + return cipherText[aes.BlockSize:] +} + +// CompareHashAndPassword compares the hash and +// password, returns nil if equal otherwise an error. Context can be used to +// cancel the hashing if the algorithm supports it. +func CompareHashAndPassword(ctx context.Context, hash, password string) error { + if strings.HasPrefix(hash, Argon2Prefix) { + return compareHashAndPasswordArgon2(ctx, hash, password) + } else if strings.HasPrefix(hash, FirebaseScryptPrefix) { + return compareHashAndPasswordFirebaseScrypt(ctx, hash, password) + } + + // assume bcrypt + hashCost, err := bcrypt.Cost([]byte(hash)) + if err != nil { + return err + } + + attributes := []attribute.KeyValue{ + attribute.String("alg", "bcrypt"), + attribute.Int("bcrypt_cost", hashCost), + } + + compareHashAndPasswordSubmittedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + defer func() { + attributes = append(attributes, attribute.Bool( + "match", + !errors.Is(err, bcrypt.ErrMismatchedHashAndPassword), + )) + + compareHashAndPasswordCompletedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + }() + + err = bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) + return err +} + +// GenerateFromPassword generates a password hash from a +// password, using PasswordHashCost. Context can be used to cancel the hashing +// if the algorithm supports it. 
+func GenerateFromPassword(ctx context.Context, password string) (string, error) { + hashCost := bcrypt.DefaultCost + + switch PasswordHashCost { + case QuickHashCost: + hashCost = bcrypt.MinCost + } + + attributes := []attribute.KeyValue{ + attribute.String("alg", "bcrypt"), + attribute.Int("bcrypt_cost", hashCost), + } + + generateFromPasswordSubmittedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + defer generateFromPasswordCompletedCounter.Add(ctx, 1, metric.WithAttributes(attributes...)) + + hash := must(bcrypt.GenerateFromPassword([]byte(password), hashCost)) + + return string(hash), nil +} + +func GeneratePassword(requiredChars []string, length int) string { + passwordBuilder := strings.Builder{} + passwordBuilder.Grow(length) + + // Add required characters + for _, group := range requiredChars { + if len(group) > 0 { + randomIndex := secureRandomInt(len(group)) + + passwordBuilder.WriteByte(group[randomIndex]) + } + } + + // Define a default character set for random generation (if needed) + const allChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + + // Fill the rest of the password + for passwordBuilder.Len() < length { + randomIndex := secureRandomInt(len(allChars)) + passwordBuilder.WriteByte(allChars[randomIndex]) + } + + // Convert to byte slice for shuffling + passwordBytes := []byte(passwordBuilder.String()) + + // Secure shuffling + for i := len(passwordBytes) - 1; i > 0; i-- { + j := secureRandomInt(i + 1) + + passwordBytes[i], passwordBytes[j] = passwordBytes[j], passwordBytes[i] + } + + return string(passwordBytes) +} diff --git a/auth_v2.169.0/internal/crypto/password_test.go b/auth_v2.169.0/internal/crypto/password_test.go new file mode 100644 index 0000000..289c9fe --- /dev/null +++ b/auth_v2.169.0/internal/crypto/password_test.go @@ -0,0 +1,178 @@ +package crypto + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestArgon2(t *testing.T) { + // all of these hash the `test` string with various parameters + + examples := []string{ + "$argon2i$v=19$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + "$argon2id$v=19$m=32,t=3,p=2$SFVpOWJ0eXhjRzVkdGN1RQ$RXnb8rh7LaDcn07xsssqqulZYXOM/EUCEFMVcAcyYVk", + } + + for _, example := range examples { + assert.NoError(t, CompareHashAndPassword(context.Background(), example, "test")) + } + + for _, example := range examples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "test1")) + } + + negativeExamples := []string{ + // 2d + "$argon2d$v=19$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // v=16 + "$argon2id$v=16$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // data + "$argon2id$v=19$m=16,t=2,p=1,data=abc$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // keyid + "$argon2id$v=19$m=16,t=2,p=1,keyid=abc$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // m larger than 32 bits + "$argon2id$v=19$m=4294967297,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // t larger than 32 bits + "$argon2id$v=19$m=16,t=4294967297,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // p larger than 8 bits + "$argon2id$v=19$m=16,t=2,p=256$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + // salt not Base64 + "$argon2id$v=19$m=16,t=2,p=1$!!!$NfEnUOuUpb7F2fQkgFUG4g", + // hash not Base64 + "$argon2id$v=19$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$!!!", + // salt empty + "$argon2id$v=19$m=16,t=2,p=1$$NfEnUOuUpb7F2fQkgFUG4g", + // hash empty + "$argon2id$v=19$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$", 
+ } + + for _, example := range negativeExamples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "test")) + } +} + +func TestGeneratePassword(t *testing.T) { + tests := []struct { + name string + requiredChars []string + length int + }{ + { + name: "Valid password generation", + requiredChars: []string{"ABC", "123", "@#$"}, + length: 12, + }, + { + name: "Empty required chars", + requiredChars: []string{}, + length: 8, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GeneratePassword(tt.requiredChars, tt.length) + + if len(got) != tt.length { + t.Errorf("GeneratePassword() returned password of length %d, want %d", len(got), tt.length) + } + + // Check if all required characters are present + for _, chars := range tt.requiredChars { + found := false + for _, c := range got { + if strings.ContainsRune(chars, c) { + found = true + break + } + } + if !found && len(chars) > 0 { + t.Errorf("GeneratePassword() missing required character from set %s", chars) + } + } + }) + } + + // Check for duplicates passwords + passwords := make(map[string]bool) + for i := 0; i < 30; i++ { + p := GeneratePassword([]string{"ABC", "123", "@#$"}, 30) + + if passwords[p] { + t.Errorf("GeneratePassword() generated duplicate password: %s", p) + } + passwords[p] = true + } +} + +func TestFirebaseScrypt(t *testing.T) { + // all of these use the `mytestpassword` string as the valid one + + examples := []string{ + "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + } + + for _, example := range examples { + assert.NoError(t, CompareHashAndPassword(context.Background(), example, "mytestpassword")) + } + + for _, example := range examples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "mytestpassword1")) + } + + negativeExamples := []string{ + // v not 1 + "$fbscrypt$v=2,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // n not 32 bits + "$fbscrypt$v=1,n=4294967297,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // n is 0 + "$fbscrypt$v=1,n=0,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // rounds is not 64 bits + "$fbscrypt$v=1,n=14,r=18446744073709551617,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // rounds is 0 + "$fbscrypt$v=1,n=14,r=0,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // threads is not 8 bits + 
"$fbscrypt$v=1,n=14,r=8,p=256,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // threads is 0 + "$fbscrypt$v=1,n=14,r=8,p=0,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // hash is not base64 + "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$!!!", + // salt is not base64 + "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$!!!$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // signer key is not base64 + "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=!!!$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + // salt separator is not base64 + "$fbscrypt$v=1,n=14,r=8,p=1,ss=!!!,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$zKVTMvnWVw5BBOZNUdnsalx4c4c7y/w7IS5p6Ut2+CfEFFlz37J9huyQfov4iizN8dbjvEJlM5tQaJP84+hfTw==", + } + + for _, example := range negativeExamples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "mytestpassword")) + } +} + +func TestBcrypt(t *testing.T) { + // all use the `test` password + + examples := []string{ + "$2y$04$mIJxfrCaEI3GukZe11CiXublhEFanu5.ododkll1WphfSp6pn4zIu", + "$2y$10$srNl09aPtc2qr.0Vl.NtjekJRt/NxRxYQm3qd3OvfcKsJgVnr6.Ve", + } + + for _, example := range examples { + assert.NoError(t, CompareHashAndPassword(context.Background(), example, "test")) + } + + for _, example := range examples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "test1")) + } + + negativeExamples := []string{ + "not-a-hash", + } + for _, example := range negativeExamples { + assert.Error(t, CompareHashAndPassword(context.Background(), example, "test")) + } +} diff --git a/auth_v2.169.0/internal/crypto/utils.go b/auth_v2.169.0/internal/crypto/utils.go new file mode 100644 index 0000000..a6b38b8 --- /dev/null +++ b/auth_v2.169.0/internal/crypto/utils.go @@ -0,0 +1,9 @@ +package crypto + +func must[T any](a T, err error) T { + if err != nil { + panic(err) + } + + return a +} diff --git a/auth_v2.169.0/internal/crypto/utils_test.go b/auth_v2.169.0/internal/crypto/utils_test.go new file mode 100644 index 0000000..1aeeab8 --- /dev/null +++ b/auth_v2.169.0/internal/crypto/utils_test.go @@ -0,0 +1,14 @@ +package crypto + +import ( + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestMust(t *testing.T) { + require.Panics(t, func() { + must(123, errors.New("panic")) + }) +} diff --git a/auth_v2.169.0/internal/hooks/auth_hooks.go b/auth_v2.169.0/internal/hooks/auth_hooks.go new file mode 100644 index 0000000..1b881d3 --- /dev/null +++ b/auth_v2.169.0/internal/hooks/auth_hooks.go @@ -0,0 +1,220 @@ +package hooks + +import ( + "github.com/gofrs/uuid" + "github.com/golang-jwt/jwt/v5" + "github.com/supabase/auth/internal/mailer" + "github.com/supabase/auth/internal/models" +) + +type HookType string + +const ( + PostgresHook HookType = "pg-functions" +) + +const ( + // In Miliseconds + DefaultTimeout = 2000 +) + +// Hook Names +const ( + HookRejection = "reject" +) + 
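
DefaultTimeout above is the per-hook budget for a pg-functions hook, expressed in milliseconds. As a rough, self-contained sketch of how that budget could be applied as a context deadline (the runHook helper and the wiring are assumptions for illustration, not code from this diff):

package main

import (
	"context"
	"fmt"
	"time"
)

// Mirrors the DefaultTimeout constant above (milliseconds).
const DefaultTimeout = 2000

// runHook is a hypothetical stand-in for dispatching a pg-functions hook.
func runHook(ctx context.Context) error {
	select {
	case <-time.After(50 * time.Millisecond): // simulated hook work
		return nil
	case <-ctx.Done():
		return ctx.Err() // the hook ran past its budget
	}
}

func main() {
	// Turn the millisecond constant into a context deadline.
	ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout*time.Millisecond)
	defer cancel()

	if err := runHook(ctx); err != nil {
		fmt.Println("hook failed:", err)
		return
	}
	fmt.Println("hook completed within budget")
}
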
+type HTTPHookInput interface { + IsHTTPHook() +} + +type HookOutput interface { + IsError() bool + Error() string +} + +// TODO(joel): Move this to phone package +type SMS struct { + OTP string `json:"otp,omitempty"` + SMSType string `json:"sms_type,omitempty"` +} + +// #nosec +const MinimumViableTokenSchema = `{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "aud": { + "type": ["string", "array"] + }, + "exp": { + "type": "integer" + }, + "jti": { + "type": "string" + }, + "iat": { + "type": "integer" + }, + "iss": { + "type": "string" + }, + "nbf": { + "type": "integer" + }, + "sub": { + "type": "string" + }, + "email": { + "type": "string" + }, + "phone": { + "type": "string" + }, + "app_metadata": { + "type": "object", + "additionalProperties": true + }, + "user_metadata": { + "type": "object", + "additionalProperties": true + }, + "role": { + "type": "string" + }, + "aal": { + "type": "string" + }, + "amr": { + "type": "array", + "items": { + "type": "object" + } + }, + "session_id": { + "type": "string" + } + }, + "required": ["aud", "exp", "iat", "sub", "email", "phone", "role", "aal", "session_id", "is_anonymous"] +}` + +// AccessTokenClaims is a struct thats used for JWT claims +type AccessTokenClaims struct { + jwt.RegisteredClaims + Email string `json:"email"` + Phone string `json:"phone"` + AppMetaData map[string]interface{} `json:"app_metadata"` + UserMetaData map[string]interface{} `json:"user_metadata"` + Role string `json:"role"` + AuthenticatorAssuranceLevel string `json:"aal,omitempty"` + AuthenticationMethodReference []models.AMREntry `json:"amr,omitempty"` + SessionId string `json:"session_id,omitempty"` + IsAnonymous bool `json:"is_anonymous"` +} + +type MFAVerificationAttemptInput struct { + UserID uuid.UUID `json:"user_id"` + FactorID uuid.UUID `json:"factor_id"` + FactorType string `json:"factor_type"` + Valid bool `json:"valid"` +} + +type MFAVerificationAttemptOutput struct { + Decision string `json:"decision"` + Message string `json:"message"` + HookError AuthHookError `json:"error"` +} + +type PasswordVerificationAttemptInput struct { + UserID uuid.UUID `json:"user_id"` + Valid bool `json:"valid"` +} + +type PasswordVerificationAttemptOutput struct { + Decision string `json:"decision"` + Message string `json:"message"` + ShouldLogoutUser bool `json:"should_logout_user"` + HookError AuthHookError `json:"error"` +} + +type CustomAccessTokenInput struct { + UserID uuid.UUID `json:"user_id"` + Claims *AccessTokenClaims `json:"claims"` + AuthenticationMethod string `json:"authentication_method"` +} + +type CustomAccessTokenOutput struct { + Claims map[string]interface{} `json:"claims"` + HookError AuthHookError `json:"error,omitempty"` +} + +type SendSMSInput struct { + User *models.User `json:"user,omitempty"` + SMS SMS `json:"sms,omitempty"` +} + +type SendSMSOutput struct { + HookError AuthHookError `json:"error,omitempty"` +} + +type SendEmailInput struct { + User *models.User `json:"user"` + EmailData mailer.EmailData `json:"email_data"` +} + +type SendEmailOutput struct { + HookError AuthHookError `json:"error,omitempty"` +} + +func (mf *MFAVerificationAttemptOutput) IsError() bool { + return mf.HookError.Message != "" +} + +func (mf *MFAVerificationAttemptOutput) Error() string { + return mf.HookError.Message +} + +func (p *PasswordVerificationAttemptOutput) IsError() bool { + return p.HookError.Message != "" +} + +func (p *PasswordVerificationAttemptOutput) Error() string { + return p.HookError.Message +} + 
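
The hook output structs above all expose the HookOutput contract in the same way: IsError reports whether HookError carries a message, and Error returns it (the method definitions follow just below). A minimal, self-contained sketch of a caller consuming a CustomAccessTokenOutput under that contract; the local type copies and sample claims are illustrative only, not part of this diff:

package main

import "fmt"

// Local copies of the shapes defined above, trimmed for the sketch.
type AuthHookError struct {
	HTTPCode int    `json:"http_code,omitempty"`
	Message  string `json:"message,omitempty"`
}

type CustomAccessTokenOutput struct {
	Claims    map[string]interface{} `json:"claims"`
	HookError AuthHookError          `json:"error,omitempty"`
}

func (ca *CustomAccessTokenOutput) IsError() bool { return ca.HookError.Message != "" }
func (ca *CustomAccessTokenOutput) Error() string { return ca.HookError.Message }

func main() {
	// Hypothetical result as a custom-access-token hook might return it.
	out := &CustomAccessTokenOutput{
		Claims: map[string]interface{}{"role": "authenticated", "custom_claim": true},
	}

	// Callers check IsError before trusting Claims.
	if out.IsError() {
		fmt.Println("token customization rejected:", out.Error())
		return
	}
	fmt.Println("claims to merge into the JWT:", out.Claims)
}
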
+func (ca *CustomAccessTokenOutput) IsError() bool { + return ca.HookError.Message != "" +} + +func (ca *CustomAccessTokenOutput) Error() string { + return ca.HookError.Message +} + +func (cs *SendSMSOutput) IsError() bool { + return cs.HookError.Message != "" +} + +func (cs *SendSMSOutput) Error() string { + return cs.HookError.Message +} + +func (cs *SendEmailOutput) IsError() bool { + return cs.HookError.Message != "" +} + +func (cs *SendEmailOutput) Error() string { + return cs.HookError.Message +} + +type AuthHookError struct { + HTTPCode int `json:"http_code,omitempty"` + Message string `json:"message,omitempty"` +} + +func (a *AuthHookError) Error() string { + return a.Message +} + +const ( + DefaultMFAHookRejectionMessage = "Further MFA verification attempts will be rejected." + DefaultPasswordHookRejectionMessage = "Further password verification attempts will be rejected." +) diff --git a/auth_v2.169.0/internal/mailer/mailer.go b/auth_v2.169.0/internal/mailer/mailer.go new file mode 100644 index 0000000..1499960 --- /dev/null +++ b/auth_v2.169.0/internal/mailer/mailer.go @@ -0,0 +1,93 @@ +package mailer + +import ( + "fmt" + "net/http" + "net/url" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +// Mailer defines the interface a mailer must implement. +type Mailer interface { + InviteMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error + ConfirmationMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error + RecoveryMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error + MagicLinkMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error + EmailChangeMail(r *http.Request, user *models.User, otpNew, otpCurrent, referrerURL string, externalURL *url.URL) error + ReauthenticateMail(r *http.Request, user *models.User, otp string) error + GetEmailActionLink(user *models.User, actionType, referrerURL string, externalURL *url.URL) (string, error) +} + +type EmailParams struct { + Token string + Type string + RedirectTo string +} + +type EmailData struct { + Token string `json:"token"` + TokenHash string `json:"token_hash"` + RedirectTo string `json:"redirect_to"` + EmailActionType string `json:"email_action_type"` + SiteURL string `json:"site_url"` + TokenNew string `json:"token_new"` + TokenHashNew string `json:"token_hash_new"` +} + +// NewMailer returns a new gotrue mailer +func NewMailer(globalConfig *conf.GlobalConfiguration) Mailer { + from := globalConfig.SMTP.FromAddress() + u, _ := url.ParseRequestURI(globalConfig.API.ExternalURL) + + var mailClient MailClient + if globalConfig.SMTP.Host == "" { + logrus.Infof("Noop mail client being used for %v", globalConfig.SiteURL) + mailClient = &noopMailClient{ + EmailValidator: newEmailValidator(globalConfig.Mailer), + } + } else { + mailClient = &MailmeMailer{ + Host: globalConfig.SMTP.Host, + Port: globalConfig.SMTP.Port, + User: globalConfig.SMTP.User, + Pass: globalConfig.SMTP.Pass, + LocalName: u.Hostname(), + From: from, + BaseURL: globalConfig.SiteURL, + Logger: logrus.StandardLogger(), + MailLogging: globalConfig.SMTP.LoggingEnabled, + EmailValidator: newEmailValidator(globalConfig.Mailer), + } + } + + return &TemplateMailer{ + SiteURL: globalConfig.SiteURL, + Config: globalConfig, + Mailer: mailClient, + } +} + +func withDefault(value, defaultValue string) string { + if value == "" { + return defaultValue 
+ } + return value +} + +func getPath(filepath string, params *EmailParams) (*url.URL, error) { + path := &url.URL{} + if filepath != "" { + if p, err := url.Parse(filepath); err != nil { + return nil, err + } else { + path = p + } + } + if params != nil { + path.RawQuery = fmt.Sprintf("token=%s&type=%s&redirect_to=%s", url.QueryEscape(params.Token), url.QueryEscape(params.Type), encodeRedirectURL(params.RedirectTo)) + } + return path, nil +} diff --git a/auth_v2.169.0/internal/mailer/mailer_test.go b/auth_v2.169.0/internal/mailer/mailer_test.go new file mode 100644 index 0000000..290d65d --- /dev/null +++ b/auth_v2.169.0/internal/mailer/mailer_test.go @@ -0,0 +1,87 @@ +package mailer + +import ( + "net/url" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" +) + +var urlRegexp = regexp.MustCompile(`^https?://[^/]+`) + +func enforceRelativeURL(url string) string { + return urlRegexp.ReplaceAllString(url, "") +} + +func TestGetPath(t *testing.T) { + params := EmailParams{ + Token: "token", + Type: "signup", + RedirectTo: "https://example.com", + } + cases := []struct { + SiteURL string + Path string + Params *EmailParams + Expected string + }{ + { + SiteURL: "https://test.example.com", + Path: "/templates/confirm.html", + Params: nil, + Expected: "https://test.example.com/templates/confirm.html", + }, + { + SiteURL: "https://test.example.com/removedpath", + Path: "/templates/confirm.html", + Params: nil, + Expected: "https://test.example.com/templates/confirm.html", + }, + { + SiteURL: "https://test.example.com/", + Path: "/trailingslash/", + Params: nil, + Expected: "https://test.example.com/trailingslash/", + }, + { + SiteURL: "https://test.example.com", + Path: "f", + Params: &params, + Expected: "https://test.example.com/f?token=token&type=signup&redirect_to=https://example.com", + }, + { + SiteURL: "https://test.example.com", + Path: "", + Params: &params, + Expected: "https://test.example.com?token=token&type=signup&redirect_to=https://example.com", + }, + } + + for _, c := range cases { + u, err := url.ParseRequestURI(c.SiteURL) + assert.NoError(t, err, "error parsing URI request") + + path, err := getPath(c.Path, c.Params) + + assert.NoError(t, err) + assert.Equal(t, c.Expected, u.ResolveReference(path).String()) + } +} + +func TestRelativeURL(t *testing.T) { + cases := []struct { + URL string + Expected string + }{ + {"https://test.example.com", ""}, + {"http://test.example.com", ""}, + {"test.example.com", "test.example.com"}, + {"/some/path#fragment", "/some/path#fragment"}, + } + + for _, c := range cases { + res := enforceRelativeURL(c.URL) + assert.Equal(t, c.Expected, res, c.URL) + } +} diff --git a/auth_v2.169.0/internal/mailer/mailme.go b/auth_v2.169.0/internal/mailer/mailme.go new file mode 100644 index 0000000..20ff177 --- /dev/null +++ b/auth_v2.169.0/internal/mailer/mailme.go @@ -0,0 +1,230 @@ +package mailer + +import ( + "bytes" + "context" + "errors" + "html/template" + "io" + "log" + "net/http" + "strings" + "sync" + "time" + + "gopkg.in/gomail.v2" + + "github.com/sirupsen/logrus" +) + +// TemplateRetries is the number of times MailMe will try to fetch a URL before giving up +const TemplateRetries = 3 + +// TemplateExpiration is the time period that the template will be cached for +const TemplateExpiration = 10 * time.Second + +// MailmeMailer lets MailMe send templated mails +type MailmeMailer struct { + From string + Host string + Port int + User string + Pass string + BaseURL string + LocalName string + FuncMap template.FuncMap + cache *TemplateCache + Logger 
logrus.FieldLogger + MailLogging bool + EmailValidator *EmailValidator +} + +// Mail sends a templated mail. It will try to load the template from a URL, and +// otherwise fall back to the default +func (m *MailmeMailer) Mail( + ctx context.Context, + to, subjectTemplate, templateURL, defaultTemplate string, + templateData map[string]interface{}, + headers map[string][]string, + typ string, +) error { + if m.FuncMap == nil { + m.FuncMap = map[string]interface{}{} + } + if m.cache == nil { + m.cache = &TemplateCache{ + templates: map[string]*MailTemplate{}, + funcMap: m.FuncMap, + logger: m.Logger, + } + } + + if m.EmailValidator != nil { + if err := m.EmailValidator.Validate(ctx, to); err != nil { + return err + } + } + + tmp, err := template.New("Subject").Funcs(template.FuncMap(m.FuncMap)).Parse(subjectTemplate) + if err != nil { + return err + } + + subject := &bytes.Buffer{} + err = tmp.Execute(subject, templateData) + if err != nil { + return err + } + + body, err := m.MailBody(templateURL, defaultTemplate, templateData) + if err != nil { + return err + } + + mail := gomail.NewMessage() + mail.SetHeader("From", m.From) + mail.SetHeader("To", to) + mail.SetHeader("Subject", subject.String()) + + for k, v := range headers { + if v != nil { + mail.SetHeader(k, v...) + } + } + + mail.SetBody("text/html", body) + + dial := gomail.NewDialer(m.Host, m.Port, m.User, m.Pass) + if m.LocalName != "" { + dial.LocalName = m.LocalName + } + + if m.MailLogging { + defer func() { + fields := logrus.Fields{ + "event": "mail.send", + "mail_type": typ, + "mail_from": m.From, + "mail_to": to, + } + m.Logger.WithFields(fields).Info("mail.send") + }() + } + if err := dial.DialAndSend(mail); err != nil { + return err + } + return nil +} + +type MailTemplate struct { + tmp *template.Template + expiresAt time.Time +} + +type TemplateCache struct { + templates map[string]*MailTemplate + mutex sync.Mutex + funcMap template.FuncMap + logger logrus.FieldLogger +} + +func (t *TemplateCache) Get(url string) (*template.Template, error) { + cached, ok := t.templates[url] + if ok && (cached.expiresAt.Before(time.Now())) { + return cached.tmp, nil + } + data, err := t.fetchTemplate(url, TemplateRetries) + if err != nil { + return nil, err + } + return t.Set(url, data, TemplateExpiration) +} + +func (t *TemplateCache) Set(key, value string, expirationTime time.Duration) (*template.Template, error) { + parsed, err := template.New(key).Funcs(t.funcMap).Parse(value) + if err != nil { + return nil, err + } + + cached := &MailTemplate{ + tmp: parsed, + expiresAt: time.Now().Add(expirationTime), + } + t.mutex.Lock() + t.templates[key] = cached + t.mutex.Unlock() + return parsed, nil +} + +func (t *TemplateCache) fetchTemplate(url string, triesLeft int) (string, error) { + client := &http.Client{ + Timeout: 10 * time.Second, + } + + resp, err := client.Get(url) + if err != nil && triesLeft > 0 { + return t.fetchTemplate(url, triesLeft-1) + } + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode == 200 { // OK + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil && triesLeft > 0 { + return t.fetchTemplate(url, triesLeft-1) + } + if err != nil { + return "", err + } + return string(bodyBytes), err + } + if triesLeft > 0 { + return t.fetchTemplate(url, triesLeft-1) + } + return "", errors.New("mailer: unable to fetch mail template") +} + +func (m *MailmeMailer) MailBody(url string, defaultTemplate string, data map[string]interface{}) (string, error) { + if m.FuncMap == nil { + m.FuncMap = 
map[string]interface{}{} + } + if m.cache == nil { + m.cache = &TemplateCache{templates: map[string]*MailTemplate{}, funcMap: m.FuncMap} + } + + var temp *template.Template + var err error + + if url != "" { + var absoluteURL string + if strings.HasPrefix(url, "http") { + absoluteURL = url + } else { + absoluteURL = m.BaseURL + url + } + temp, err = m.cache.Get(absoluteURL) + if err != nil { + log.Printf("Error loading template from %v: %v\n", url, err) + } + } + + if temp == nil { + cached, ok := m.cache.templates[url] + if ok { + temp = cached.tmp + } else { + temp, err = m.cache.Set(url, defaultTemplate, 0) + if err != nil { + return "", err + } + } + } + + buf := &bytes.Buffer{} + err = temp.Execute(buf, data) + if err != nil { + return "", err + } + return buf.String(), nil +} diff --git a/auth_v2.169.0/internal/mailer/noop.go b/auth_v2.169.0/internal/mailer/noop.go new file mode 100644 index 0000000..0e0e3bf --- /dev/null +++ b/auth_v2.169.0/internal/mailer/noop.go @@ -0,0 +1,28 @@ +package mailer + +import ( + "context" + "errors" +) + +type noopMailClient struct { + EmailValidator *EmailValidator +} + +func (m *noopMailClient) Mail( + ctx context.Context, + to, subjectTemplate, templateURL, defaultTemplate string, + templateData map[string]interface{}, + headers map[string][]string, + typ string, +) error { + if to == "" { + return errors.New("to field cannot be empty") + } + if m.EmailValidator != nil { + if err := m.EmailValidator.Validate(ctx, to); err != nil { + return err + } + } + return nil +} diff --git a/auth_v2.169.0/internal/mailer/template.go b/auth_v2.169.0/internal/mailer/template.go new file mode 100644 index 0000000..59a4854 --- /dev/null +++ b/auth_v2.169.0/internal/mailer/template.go @@ -0,0 +1,420 @@ +package mailer + +import ( + "context" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/models" +) + +type MailRequest struct { + To string + SubjectTemplate string + TemplateURL string + DefaultTemplate string + TemplateData map[string]interface{} + Headers map[string][]string + Type string +} + +type MailClient interface { + Mail( + ctx context.Context, + to string, + subjectTemplate string, + templateURL string, + defaultTemplate string, + templateData map[string]interface{}, + headers map[string][]string, + typ string, + ) error +} + +// TemplateMailer will send mail and use templates from the site for easy mail styling +type TemplateMailer struct { + SiteURL string + Config *conf.GlobalConfiguration + Mailer MailClient +} + +func encodeRedirectURL(referrerURL string) string { + if len(referrerURL) > 0 { + if strings.ContainsAny(referrerURL, "&=#") { + // if the string contains &, = or # it has not been URL + // encoded by the caller, which means it should be URL + // encoded by us otherwise, it should be taken as-is + referrerURL = url.QueryEscape(referrerURL) + } + } + return referrerURL +} + +const ( + SignupVerification = "signup" + RecoveryVerification = "recovery" + InviteVerification = "invite" + MagicLinkVerification = "magiclink" + EmailChangeVerification = "email_change" + EmailOTPVerification = "email" + EmailChangeCurrentVerification = "email_change_current" + EmailChangeNewVerification = "email_change_new" + ReauthenticationVerification = "reauthentication" +) + +const defaultInviteMail = `

<h2>You have been invited</h2>
+
+<p>You have been invited to create a user on {{ .SiteURL }}. Follow this link to accept the invite:</p>
+<p><a href="{{ .ConfirmationURL }}">Accept the invite</a></p>
+<p>Alternatively, enter the code: {{ .Token }}</p>`
+
+const defaultConfirmationMail = `<h2>Confirm your email</h2>
+
+<p>Follow this link to confirm your email:</p>
+<p><a href="{{ .ConfirmationURL }}">Confirm your email address</a></p>
+<p>Alternatively, enter the code: {{ .Token }}</p>`
+
+const defaultRecoveryMail = `<h2>Reset password</h2>
+
+<p>Follow this link to reset the password for your user:</p>
+<p><a href="{{ .ConfirmationURL }}">Reset password</a></p>
+<p>Alternatively, enter the code: {{ .Token }}</p>`
+
+const defaultMagicLinkMail = `<h2>Magic Link</h2>
+
+<p>Follow this link to login:</p>
+<p><a href="{{ .ConfirmationURL }}">Log In</a></p>
+<p>Alternatively, enter the code: {{ .Token }}</p>`
+
+const defaultEmailChangeMail = `<h2>Confirm email address change</h2>
+
+<p>Follow this link to confirm the update of your email address from {{ .Email }} to {{ .NewEmail }}:</p>
+<p><a href="{{ .ConfirmationURL }}">Change email address</a></p>
+<p>Alternatively, enter the code: {{ .Token }}</p>`
+
+const defaultReauthenticateMail = `<h2>Confirm reauthentication</h2>
+
+<p>Enter the code: {{ .Token }}</p>

` + +func (m *TemplateMailer) Headers(messageType string) map[string][]string { + originalHeaders := m.Config.SMTP.NormalizedHeaders() + + if originalHeaders == nil { + return nil + } + + headers := make(map[string][]string, len(originalHeaders)) + + for header, values := range originalHeaders { + replacedValues := make([]string, 0, len(values)) + + if header == "" { + continue + } + + for _, value := range values { + if value == "" { + continue + } + + // TODO: in the future, use a templating engine to add more contextual data available to headers + if strings.Contains(value, "$messageType") { + replacedValues = append(replacedValues, strings.ReplaceAll(value, "$messageType", messageType)) + } else { + replacedValues = append(replacedValues, value) + } + } + + headers[header] = replacedValues + } + + return headers +} + +// InviteMail sends a invite mail to a new user +func (m *TemplateMailer) InviteMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.Config.Mailer.URLPaths.Invite, &EmailParams{ + Token: user.ConfirmationToken, + Type: "invite", + RedirectTo: referrerURL, + }) + + if err != nil { + return err + } + + data := map[string]interface{}{ + "SiteURL": m.Config.SiteURL, + "ConfirmationURL": externalURL.ResolveReference(path).String(), + "Email": user.Email, + "Token": otp, + "TokenHash": user.ConfirmationToken, + "Data": user.UserMetaData, + "RedirectTo": referrerURL, + } + + return m.Mailer.Mail( + r.Context(), + user.GetEmail(), + withDefault(m.Config.Mailer.Subjects.Invite, "You have been invited"), + m.Config.Mailer.Templates.Invite, + defaultInviteMail, + data, + m.Headers("invite"), + "invite", + ) +} + +// ConfirmationMail sends a signup confirmation mail to a new user +func (m *TemplateMailer) ConfirmationMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.Config.Mailer.URLPaths.Confirmation, &EmailParams{ + Token: user.ConfirmationToken, + Type: "signup", + RedirectTo: referrerURL, + }) + if err != nil { + return err + } + + data := map[string]interface{}{ + "SiteURL": m.Config.SiteURL, + "ConfirmationURL": externalURL.ResolveReference(path).String(), + "Email": user.Email, + "Token": otp, + "TokenHash": user.ConfirmationToken, + "Data": user.UserMetaData, + "RedirectTo": referrerURL, + } + + return m.Mailer.Mail( + r.Context(), + user.GetEmail(), + withDefault(m.Config.Mailer.Subjects.Confirmation, "Confirm Your Email"), + m.Config.Mailer.Templates.Confirmation, + defaultConfirmationMail, + data, + m.Headers("confirm"), + "confirm", + ) +} + +// ReauthenticateMail sends a reauthentication mail to an authenticated user +func (m *TemplateMailer) ReauthenticateMail(r *http.Request, user *models.User, otp string) error { + data := map[string]interface{}{ + "SiteURL": m.Config.SiteURL, + "Email": user.Email, + "Token": otp, + "Data": user.UserMetaData, + } + + return m.Mailer.Mail( + r.Context(), + user.GetEmail(), + withDefault(m.Config.Mailer.Subjects.Reauthentication, "Confirm reauthentication"), + m.Config.Mailer.Templates.Reauthentication, + defaultReauthenticateMail, + data, + m.Headers("reauthenticate"), + "reauthenticate", + ) +} + +// EmailChangeMail sends an email change confirmation mail to a user +func (m *TemplateMailer) EmailChangeMail(r *http.Request, user *models.User, otpNew, otpCurrent, referrerURL string, externalURL *url.URL) error { + type Email struct { + Address string + Otp string + TokenHash string + Subject string 
+ Template string + } + emails := []Email{ + { + Address: user.EmailChange, + Otp: otpNew, + TokenHash: user.EmailChangeTokenNew, + Subject: withDefault(m.Config.Mailer.Subjects.EmailChange, "Confirm Email Change"), + Template: m.Config.Mailer.Templates.EmailChange, + }, + } + + currentEmail := user.GetEmail() + if m.Config.Mailer.SecureEmailChangeEnabled && currentEmail != "" { + emails = append(emails, Email{ + Address: currentEmail, + Otp: otpCurrent, + TokenHash: user.EmailChangeTokenCurrent, + Subject: withDefault(m.Config.Mailer.Subjects.Confirmation, "Confirm Email Address"), + Template: m.Config.Mailer.Templates.EmailChange, + }) + } + + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + errors := make(chan error, len(emails)) + for _, email := range emails { + path, err := getPath( + m.Config.Mailer.URLPaths.EmailChange, + &EmailParams{ + Token: email.TokenHash, + Type: "email_change", + RedirectTo: referrerURL, + }, + ) + if err != nil { + return err + } + go func(address, token, tokenHash, template string) { + data := map[string]interface{}{ + "SiteURL": m.Config.SiteURL, + "ConfirmationURL": externalURL.ResolveReference(path).String(), + "Email": user.GetEmail(), + "NewEmail": user.EmailChange, + "Token": token, + "TokenHash": tokenHash, + "SendingTo": address, + "Data": user.UserMetaData, + "RedirectTo": referrerURL, + } + errors <- m.Mailer.Mail( + ctx, + address, + withDefault(m.Config.Mailer.Subjects.EmailChange, "Confirm Email Change"), + template, + defaultEmailChangeMail, + data, + m.Headers("email_change"), + "email_change", + ) + }(email.Address, email.Otp, email.TokenHash, email.Template) + } + + for i := 0; i < len(emails); i++ { + e := <-errors + if e != nil { + return e + } + } + return nil +} + +// RecoveryMail sends a password recovery mail +func (m *TemplateMailer) RecoveryMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ + Token: user.RecoveryToken, + Type: "recovery", + RedirectTo: referrerURL, + }) + if err != nil { + return err + } + data := map[string]interface{}{ + "SiteURL": m.Config.SiteURL, + "ConfirmationURL": externalURL.ResolveReference(path).String(), + "Email": user.Email, + "Token": otp, + "TokenHash": user.RecoveryToken, + "Data": user.UserMetaData, + "RedirectTo": referrerURL, + } + + return m.Mailer.Mail( + r.Context(), + user.GetEmail(), + withDefault(m.Config.Mailer.Subjects.Recovery, "Reset Your Password"), + m.Config.Mailer.Templates.Recovery, + defaultRecoveryMail, + data, + m.Headers("recovery"), + "recovery", + ) +} + +// MagicLinkMail sends a login link mail +func (m *TemplateMailer) MagicLinkMail(r *http.Request, user *models.User, otp, referrerURL string, externalURL *url.URL) error { + path, err := getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ + Token: user.RecoveryToken, + Type: "magiclink", + RedirectTo: referrerURL, + }) + if err != nil { + return err + } + + data := map[string]interface{}{ + "SiteURL": m.Config.SiteURL, + "ConfirmationURL": externalURL.ResolveReference(path).String(), + "Email": user.Email, + "Token": otp, + "TokenHash": user.RecoveryToken, + "Data": user.UserMetaData, + "RedirectTo": referrerURL, + } + + return m.Mailer.Mail( + r.Context(), + user.GetEmail(), + withDefault(m.Config.Mailer.Subjects.MagicLink, "Your Magic Link"), + m.Config.Mailer.Templates.MagicLink, + defaultMagicLinkMail, + data, + m.Headers("magiclink"), + "magiclink", + ) +} + +// GetEmailActionLink 
returns a magiclink, recovery or invite link based on the actionType passed. +func (m TemplateMailer) GetEmailActionLink(user *models.User, actionType, referrerURL string, externalURL *url.URL) (string, error) { + var err error + var path *url.URL + + switch actionType { + case "magiclink": + path, err = getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ + Token: user.RecoveryToken, + Type: "magiclink", + RedirectTo: referrerURL, + }) + case "recovery": + path, err = getPath(m.Config.Mailer.URLPaths.Recovery, &EmailParams{ + Token: user.RecoveryToken, + Type: "recovery", + RedirectTo: referrerURL, + }) + case "invite": + path, err = getPath(m.Config.Mailer.URLPaths.Invite, &EmailParams{ + Token: user.ConfirmationToken, + Type: "invite", + RedirectTo: referrerURL, + }) + case "signup": + path, err = getPath(m.Config.Mailer.URLPaths.Confirmation, &EmailParams{ + Token: user.ConfirmationToken, + Type: "signup", + RedirectTo: referrerURL, + }) + case "email_change_current": + path, err = getPath(m.Config.Mailer.URLPaths.EmailChange, &EmailParams{ + Token: user.EmailChangeTokenCurrent, + Type: "email_change", + RedirectTo: referrerURL, + }) + case "email_change_new": + path, err = getPath(m.Config.Mailer.URLPaths.EmailChange, &EmailParams{ + Token: user.EmailChangeTokenNew, + Type: "email_change", + RedirectTo: referrerURL, + }) + default: + return "", fmt.Errorf("invalid email action link type: %s", actionType) + } + if err != nil { + return "", err + } + return externalURL.ResolveReference(path).String(), nil +} diff --git a/auth_v2.169.0/internal/mailer/template_test.go b/auth_v2.169.0/internal/mailer/template_test.go new file mode 100644 index 0000000..f8fcd74 --- /dev/null +++ b/auth_v2.169.0/internal/mailer/template_test.go @@ -0,0 +1,65 @@ +package mailer + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +func TestTemplateHeaders(t *testing.T) { + cases := []struct { + from string + typ string + exp map[string][]string + }{ + { + from: `{"x-supabase-project-ref": ["abcjrhohrqmvcpjpsyzc"]}`, + typ: "OTHER-TYPE", + exp: map[string][]string{ + "x-supabase-project-ref": {"abcjrhohrqmvcpjpsyzc"}, + }, + }, + + { + from: `{"X-Test-A": ["test-a", "test-b"], "X-Test-B": ["test-c", "abc $messageType"]}`, + typ: "TEST-MESSAGE-TYPE", + exp: map[string][]string{ + "X-Test-A": {"test-a", "test-b"}, + "X-Test-B": {"test-c", "abc TEST-MESSAGE-TYPE"}, + }, + }, + + { + from: `{"X-Test-A": ["test-a", "test-b"], "X-Test-B": ["test-c", "abc $messageType"]}`, + typ: "OTHER-TYPE", + exp: map[string][]string{ + "X-Test-A": {"test-a", "test-b"}, + "X-Test-B": {"test-c", "abc OTHER-TYPE"}, + }, + }, + + { + from: `{"X-Test-A": ["test-a", "test-b"], "X-Test-B": ["test-c", "abc $messageType"], "x-supabase-project-ref": ["abcjrhohrqmvcpjpsyzc"]}`, + typ: "OTHER-TYPE", + exp: map[string][]string{ + "X-Test-A": {"test-a", "test-b"}, + "X-Test-B": {"test-c", "abc OTHER-TYPE"}, + "x-supabase-project-ref": {"abcjrhohrqmvcpjpsyzc"}, + }, + }, + } + for _, tc := range cases { + mailer := TemplateMailer{ + Config: &conf.GlobalConfiguration{ + SMTP: conf.SMTPConfiguration{ + Headers: tc.from, + }, + }, + } + require.NoError(t, mailer.Config.SMTP.Validate()) + + hdrs := mailer.Headers(tc.typ) + require.Equal(t, hdrs, tc.exp) + } +} diff --git a/auth_v2.169.0/internal/mailer/validate.go b/auth_v2.169.0/internal/mailer/validate.go new file mode 100644 index 0000000..1827466 --- /dev/null +++ b/auth_v2.169.0/internal/mailer/validate.go @@ -0,0 +1,298 @@ 
+package mailer
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "io"
+ "net"
+ "net/http"
+ "net/mail"
+ "strings"
+ "time"
+
+ "github.com/supabase/auth/internal/conf"
+ "golang.org/x/sync/errgroup"
+)
+
+var invalidEmailMap = map[string]bool{
+
+ // People type these often enough to be special cased.
+ "test@gmail.com": true,
+ "example@gmail.com": true,
+ "someone@gmail.com": true,
+ "test@email.com": true,
+}
+
+var invalidHostSuffixes = []string{
+
+ // These are based on the reserved names in Section 2 of RFC 2606[1].
+ //
+ // [1] https://www.rfc-editor.org/rfc/rfc2606.html#section-2
+ ".test",
+ ".example",
+ ".invalid",
+ ".local",
+ ".localhost",
+}
+
+var invalidHostMap = map[string]bool{
+
+ // These exist here too for when they are typed as "test@test"
+ "test": true,
+ "example": true,
+ "invalid": true,
+ "local": true,
+ "localhost": true,
+
+ // These are commonly typed and have DNS records which cause a
+ // large enough volume of bounce backs to special case.
+ "test.com": true,
+ "example.com": true,
+ "example.net": true,
+ "example.org": true,
+
+ // Hundreds of typos per day for this.
+ "gamil.com": true,
+
+ // These are not email providers, but people often use them.
+ "anonymous.com": true,
+ "email.com": true,
+}
+
+const (
+ validateEmailTimeout = 3 * time.Second
+)
+
+var (
+ // We use the default resolver for this.
+ validateEmailResolver net.Resolver
+)
+
+var (
+ ErrInvalidEmailAddress = errors.New("invalid_email_address")
+ ErrInvalidEmailFormat = errors.New("invalid_email_format")
+ ErrInvalidEmailDNS = errors.New("invalid_email_dns")
+)
+
+type EmailValidator struct {
+ extended bool
+ serviceURL string
+ serviceHeaders map[string][]string
+}
+
+func newEmailValidator(mc conf.MailerConfiguration) *EmailValidator {
+ return &EmailValidator{
+ extended: mc.EmailValidationExtended,
+ serviceURL: mc.EmailValidationServiceURL,
+ serviceHeaders: mc.GetEmailValidationServiceHeaders(),
+ }
+}
+
+func (ev *EmailValidator) isExtendedEnabled() bool { return ev.extended }
+func (ev *EmailValidator) isServiceEnabled() bool { return ev.serviceURL != "" }
+
+// Validate performs validation on the given email.
+//
+// When extended is true, returns a nil error in all cases but the following:
+// - `email` cannot be parsed by mail.ParseAddress
+// - `email` has a domain with no DNS configured
+//
+// When serviceURL is a non-empty string it uses the remote service to
+// determine if the email is valid.
+func (ev *EmailValidator) Validate(ctx context.Context, email string) error {
+ if !ev.isExtendedEnabled() && !ev.isServiceEnabled() {
+ return nil
+ }
+
+ // One of the two validation methods is enabled, so set a timeout.
+ ctx, cancel := context.WithTimeout(ctx, validateEmailTimeout)
+ defer cancel()
+
+ // Easier control flow here to always use errgroup; it has very little
+ // overhead in comparison to the network calls it makes. The reason
+ // we run both checks concurrently is to tighten the timeout without
+ // potentially missing a call to the validation service due to a
+ // dns timeout or something more nefarious like a honeypot dns entry.
+ g := new(errgroup.Group)
+
+ // Validate the static rules first to prevent round trips on bad emails
+ // and to parse the host ahead of time.
+ if ev.isExtendedEnabled() {
+
+ // First validate static checks such as format, known invalid hosts
+ // and any other network free checks. Running this check before we
+ // call the service will help reduce the number of calls with known
+ // invalid emails.
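+ //
+ // For example (per the tables above), "test@gmail.com", any host listed
+ // in invalidHostMap and any host ending in a reserved suffix such as
+ // ".invalid" are all rejected here without a network round trip.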
+ host, err := ev.validateStatic(email) + if err != nil { + return err + } + + // Start the goroutine to validate the host. + g.Go(func() error { return ev.validateHost(ctx, host) }) + } + + // If the service check is enabled we start a goroutine to run + // that check as well. + if ev.isServiceEnabled() { + g.Go(func() error { return ev.validateService(ctx, email) }) + } + return g.Wait() +} + +// validateStatic will validate the format and do the static checks before +// returning the host portion of the email. +func (ev *EmailValidator) validateStatic(email string) (string, error) { + if !ev.isExtendedEnabled() { + return "", nil + } + + ea, err := mail.ParseAddress(email) + if err != nil { + return "", ErrInvalidEmailFormat + } + + i := strings.LastIndex(ea.Address, "@") + if i == -1 { + return "", ErrInvalidEmailFormat + } + + // few static lookups that are typed constantly and known to be invalid. + if invalidEmailMap[email] { + return "", ErrInvalidEmailAddress + } + + host := email[i+1:] + if invalidHostMap[host] { + return "", ErrInvalidEmailDNS + } + + for i := range invalidHostSuffixes { + if strings.HasSuffix(host, invalidHostSuffixes[i]) { + return "", ErrInvalidEmailDNS + } + } + + name := email[:i] + if err := ev.validateProviders(name, host); err != nil { + return "", err + } + return host, nil +} + +func (ev *EmailValidator) validateService(ctx context.Context, email string) error { + if !ev.isServiceEnabled() { + return nil + } + + reqObject := struct { + EmailAddress string `json:"email"` + }{email} + + reqData, err := json.Marshal(&reqObject) + if err != nil { + return nil + } + + rdr := bytes.NewReader(reqData) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, ev.serviceURL, rdr) + if err != nil { + return nil + } + req.Header.Set("Content-Type", "application/json") + for name, vals := range ev.serviceHeaders { + for _, val := range vals { + req.Header.Set(name, val) + } + } + + res, err := http.DefaultClient.Do(req) + if err != nil { + return nil + } + defer res.Body.Close() + + resObject := struct { + Valid *bool `json:"valid"` + }{} + + if res.StatusCode/100 != 2 { + // we ignore the error here just in case the service is down + return nil + } + + dec := json.NewDecoder(io.LimitReader(res.Body, 1<<5)) + if err := dec.Decode(&resObject); err != nil { + return nil + } + + // If the object did not contain a valid key we consider the check as + // failed. We _must_ get a valid JSON response with a "valid" field. + if resObject.Valid == nil || *resObject.Valid { + return nil + } + + return ErrInvalidEmailAddress +} + +func (ev *EmailValidator) validateProviders(name, host string) error { + switch host { + case "gmail.com": + // Based on a sample of internal data, this reduces the number of + // bounced emails by 23%. Gmail documentation specifies that the + // min user name length is 6 characters. There may be some accounts + // from early gmail beta with shorter email addresses, but I think + // this reduces bounce rates enough to be worth adding for now. 
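+ //
+ // For example, "abcde@gmail.com" (5-character local part) is rejected by
+ // the check below, while "abcdef@gmail.com" passes it.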
+ if len(name) < 6 { + return ErrInvalidEmailAddress + } + } + return nil +} + +func (ev *EmailValidator) validateHost(ctx context.Context, host string) error { + _, err := validateEmailResolver.LookupMX(ctx, host) + if !isHostNotFound(err) { + return nil + } + + _, err = validateEmailResolver.LookupHost(ctx, host) + if !isHostNotFound(err) { + return nil + } + + // No addrs or mx records were found + return ErrInvalidEmailDNS +} + +func isHostNotFound(err error) bool { + if err == nil { + // We had no err, so we treat it as valid. We don't check the mx records + // because RFC 5321 specifies that if an empty list of MX's are returned + // the host should be treated as the MX[1]. + // + // See section 2 and 3 of: https://www.rfc-editor.org/rfc/rfc2606 + // [1] https://www.rfc-editor.org/rfc/rfc5321.html#section-5.1 + return false + } + + // No names present, we will try to get a positive assertion that the + // domain is not configured to receive email. + var dnsError *net.DNSError + if !errors.As(err, &dnsError) { + // We will be unable to determine with absolute certainy the email was + // invalid so we will err on the side of caution and return nil. + return false + } + + // The type of err is dnsError, inspect it to see if we can be certain + // the domain has no mx records currently. For this we require that + // the error was not temporary or a timeout. If those are both false + // we trust the value in IsNotFound. + if !dnsError.IsTemporary && !dnsError.IsTimeout && dnsError.IsNotFound { + return true + } + return false +} diff --git a/auth_v2.169.0/internal/mailer/validate_test.go b/auth_v2.169.0/internal/mailer/validate_test.go new file mode 100644 index 0000000..e1a86c2 --- /dev/null +++ b/auth_v2.169.0/internal/mailer/validate_test.go @@ -0,0 +1,287 @@ +package mailer + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +func TestEmalValidatorService(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Second*60) + defer cancel() + + testResVal := new(atomic.Value) + testResVal.Store(`{"valid": true}`) + + testHdrsVal := new(atomic.Value) + testHdrsVal.Store(map[string]string{"apikey": "test"}) + + // testHeaders := map[string][]string{"apikey": []string{"test"}} + testHeaders := `{"apikey": ["test"]}` + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + key := r.Header.Get("apikey") + if key == "" { + fmt.Fprintln(w, `{"error": true}`) + return + } + + fmt.Fprintln(w, testResVal.Load().(string)) + })) + defer ts.Close() + + // Return nil err from service + // when svc and extended checks both report email as valid + { + testResVal.Store(`{"valid": true}`) + cfg := conf.MailerConfiguration{ + EmailValidationExtended: true, + EmailValidationServiceURL: ts.URL, + EmailValidationServiceHeaders: testHeaders, + } + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } + + ev := newEmailValidator(cfg) + err := ev.Validate(ctx, "chris.stockton@supabase.io") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + } + + // Return nil err from service when + // extended is disabled for a known invalid address + // service reports valid + { + testResVal.Store(`{"valid": true}`) + + cfg := conf.MailerConfiguration{ + EmailValidationExtended: false, + EmailValidationServiceURL: ts.URL, + EmailValidationServiceHeaders: testHeaders, + } + if err := cfg.Validate(); 
err != nil { + t.Fatal(err) + } + + ev := newEmailValidator(cfg) + err := ev.Validate(ctx, "test@gmail.com") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + } + + // Return nil err from service when + // extended is disabled for a known invalid address + // service is disabled for a known invalid address + { + testResVal.Store(`{"valid": false}`) + + cfg := conf.MailerConfiguration{ + EmailValidationExtended: false, + EmailValidationServiceURL: "", + EmailValidationServiceHeaders: "", + } + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } + + ev := newEmailValidator(cfg) + err := ev.Validate(ctx, "test@gmail.com") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + } + + // Return err from service when + // extended reports invalid + // service is disabled for a known invalid address + { + testResVal.Store(`{"valid": true}`) + cfg := conf.MailerConfiguration{ + EmailValidationExtended: true, + EmailValidationServiceURL: "", + EmailValidationServiceHeaders: "", + } + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } + + ev := newEmailValidator(cfg) + err := ev.Validate(ctx, "test@gmail.com") + if err == nil { + t.Fatal("exp non-nil err") + } + } + + // Return err from service when + // extended reports invalid + // service reports valid + { + testResVal.Store(`{"valid": true}`) + cfg := conf.MailerConfiguration{ + EmailValidationExtended: true, + EmailValidationServiceURL: ts.URL, + EmailValidationServiceHeaders: testHeaders, + } + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } + + ev := newEmailValidator(cfg) + err := ev.Validate(ctx, "test@gmail.com") + if err == nil { + t.Fatal("exp non-nil err") + } + } + + // Return err from service when + // extended reports valid + // service reports invalid + { + testResVal.Store(`{"valid": false}`) + cfg := conf.MailerConfiguration{ + EmailValidationExtended: true, + EmailValidationServiceURL: ts.URL, + EmailValidationServiceHeaders: testHeaders, + } + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } + + ev := newEmailValidator(cfg) + err := ev.Validate(ctx, "chris.stockton@supabase.io") + if err == nil { + t.Fatal("exp non-nil err") + } + } + + // Return err from service when + // extended reports invalid + // service reports invalid + { + testResVal.Store(`{"valid": false}`) + + cfg := conf.MailerConfiguration{ + EmailValidationExtended: false, + EmailValidationServiceURL: ts.URL, + EmailValidationServiceHeaders: testHeaders, + } + if err := cfg.Validate(); err != nil { + t.Fatal(err) + } + + ev := newEmailValidator(cfg) + err := ev.Validate(ctx, "test@gmail.com") + if err == nil { + t.Fatal("exp non-nil err") + } + } +} + +func TestValidateEmailExtended(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Second*60) + defer cancel() + + cases := []struct { + email string + timeout time.Duration + err string + }{ + // valid (has mx record) + {email: "a@supabase.io"}, + {email: "support@supabase.io"}, + {email: "chris.stockton@supabase.io"}, + + // bad format + {email: "", err: "invalid_email_format"}, + {email: "io", err: "invalid_email_format"}, + {email: "supabase.io", err: "invalid_email_format"}, + {email: "@supabase.io", err: "invalid_email_format"}, + {email: "test@.supabase.io", err: "invalid_email_format"}, + + // invalid: valid mx records, but invalid and often typed + // (invalidEmailMap) + {email: "test@email.com", err: "invalid_email_address"}, + {email: "test@gmail.com", err: "invalid_email_address"}, + {email: "test@test.com", err: 
"invalid_email_dns"}, + + // very common typo + {email: "test@gamil.com", err: "invalid_email_dns"}, + + // invalid: valid mx records, but invalid and often typed + // (invalidHostMap) + {email: "a@example.com", err: "invalid_email_dns"}, + {email: "a@example.net", err: "invalid_email_dns"}, + {email: "a@example.org", err: "invalid_email_dns"}, + + // invalid: no mx records + {email: "a@test", err: "invalid_email_dns"}, + {email: "test@local", err: "invalid_email_dns"}, + {email: "test@test.local", err: "invalid_email_dns"}, + {email: "test@example", err: "invalid_email_dns"}, + {email: "test@invalid", err: "invalid_email_dns"}, + + // valid but not actually valid and typed a lot + {email: "a@invalid", err: "invalid_email_dns"}, + {email: "a@a.invalid", err: "invalid_email_dns"}, + {email: "test@invalid", err: "invalid_email_dns"}, + + // various invalid emails + {email: "test@test.localhost", err: "invalid_email_dns"}, + {email: "test@invalid.example.com", err: "invalid_email_dns"}, + {email: "test@no.such.email.host.supabase.io", err: "invalid_email_dns"}, + + // this low timeout should simulate a dns timeout, which should + // not be treated as an invalid email. + {email: "validemail@probablyaaaaaaaanotarealdomain.com", + timeout: time.Millisecond}, + + // likewise for a valid email + {email: "support@supabase.io", timeout: time.Millisecond}, + } + + cfg := conf.MailerConfiguration{ + EmailValidationExtended: true, + EmailValidationServiceURL: "", + EmailValidationServiceHeaders: "", + } + ev := newEmailValidator(cfg) + + for idx, tc := range cases { + func(timeout time.Duration) { + if timeout == 0 { + timeout = validateEmailTimeout + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + now := time.Now() + err := ev.Validate(ctx, tc.email) + dur := time.Since(now) + if max := timeout + (time.Millisecond * 50); max < dur { + t.Fatal("timeout was not respected") + } + + t.Logf("tc #%v - email %q", idx, tc.email) + if tc.err != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.err) + return + } + require.NoError(t, err) + + }(tc.timeout) + } +} diff --git a/auth_v2.169.0/internal/metering/record.go b/auth_v2.169.0/internal/metering/record.go new file mode 100644 index 0000000..d9f9c5c --- /dev/null +++ b/auth_v2.169.0/internal/metering/record.go @@ -0,0 +1,17 @@ +package metering + +import ( + "github.com/gofrs/uuid" + "github.com/sirupsen/logrus" +) + +var logger = logrus.StandardLogger().WithField("metering", true) + +func RecordLogin(loginType string, userID uuid.UUID) { + logger.WithFields(logrus.Fields{ + "action": "login", + "login_method": loginType, + "instance_id": uuid.Nil.String(), + "user_id": userID.String(), + }).Info("Login") +} diff --git a/auth_v2.169.0/internal/models/amr.go b/auth_v2.169.0/internal/models/amr.go new file mode 100644 index 0000000..fdfd883 --- /dev/null +++ b/auth_v2.169.0/internal/models/amr.go @@ -0,0 +1,43 @@ +package models + +import ( + "time" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/storage" +) + +type AMRClaim struct { + ID uuid.UUID `json:"id" db:"id"` + SessionID uuid.UUID `json:"session_id" db:"session_id"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + AuthenticationMethod *string `json:"authentication_method" db:"authentication_method"` +} + +func (AMRClaim) TableName() string { + tableName := "mfa_amr_claims" + return tableName +} + +func (cl *AMRClaim) IsAAL2Claim() bool { + 
return *cl.AuthenticationMethod == TOTPSignIn.String() || *cl.AuthenticationMethod == MFAPhone.String() || *cl.AuthenticationMethod == MFAWebAuthn.String() +} + +func AddClaimToSession(tx *storage.Connection, sessionId uuid.UUID, authenticationMethod AuthenticationMethod) error { + id := uuid.Must(uuid.NewV4()) + + currentTime := time.Now() + return tx.RawQuery("INSERT INTO "+(&pop.Model{Value: AMRClaim{}}).TableName()+ + `(id, session_id, created_at, updated_at, authentication_method) values (?, ?, ?, ?, ?) + ON CONFLICT ON CONSTRAINT mfa_amr_claims_session_id_authentication_method_pkey + DO UPDATE SET updated_at = ?;`, id, sessionId, currentTime, currentTime, authenticationMethod.String(), currentTime).Exec() +} + +func (a *AMRClaim) GetAuthenticationMethod() string { + if a.AuthenticationMethod == nil { + return "" + } + return *(a.AuthenticationMethod) +} diff --git a/auth_v2.169.0/internal/models/audit_log_entry.go b/auth_v2.169.0/internal/models/audit_log_entry.go new file mode 100644 index 0000000..5bbc9b0 --- /dev/null +++ b/auth_v2.169.0/internal/models/audit_log_entry.go @@ -0,0 +1,166 @@ +package models + +import ( + "bytes" + "fmt" + "net/http" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" +) + +type AuditAction string +type auditLogType string + +const ( + LoginAction AuditAction = "login" + LogoutAction AuditAction = "logout" + InviteAcceptedAction AuditAction = "invite_accepted" + UserSignedUpAction AuditAction = "user_signedup" + UserInvitedAction AuditAction = "user_invited" + UserDeletedAction AuditAction = "user_deleted" + UserModifiedAction AuditAction = "user_modified" + UserRecoveryRequestedAction AuditAction = "user_recovery_requested" + UserReauthenticateAction AuditAction = "user_reauthenticate_requested" + UserConfirmationRequestedAction AuditAction = "user_confirmation_requested" + UserRepeatedSignUpAction AuditAction = "user_repeated_signup" + UserUpdatePasswordAction AuditAction = "user_updated_password" + TokenRevokedAction AuditAction = "token_revoked" + TokenRefreshedAction AuditAction = "token_refreshed" + GenerateRecoveryCodesAction AuditAction = "generate_recovery_codes" + EnrollFactorAction AuditAction = "factor_in_progress" + UnenrollFactorAction AuditAction = "factor_unenrolled" + CreateChallengeAction AuditAction = "challenge_created" + VerifyFactorAction AuditAction = "verification_attempted" + DeleteFactorAction AuditAction = "factor_deleted" + DeleteRecoveryCodesAction AuditAction = "recovery_codes_deleted" + UpdateFactorAction AuditAction = "factor_updated" + MFACodeLoginAction AuditAction = "mfa_code_login" + IdentityUnlinkAction AuditAction = "identity_unlinked" + + account auditLogType = "account" + team auditLogType = "team" + token auditLogType = "token" + user auditLogType = "user" + factor auditLogType = "factor" + recoveryCodes auditLogType = "recovery_codes" +) + +var ActionLogTypeMap = map[AuditAction]auditLogType{ + LoginAction: account, + LogoutAction: account, + InviteAcceptedAction: account, + UserSignedUpAction: team, + UserInvitedAction: team, + UserDeletedAction: team, + TokenRevokedAction: token, + TokenRefreshedAction: token, + UserModifiedAction: user, + UserRecoveryRequestedAction: user, + UserConfirmationRequestedAction: user, + UserRepeatedSignUpAction: user, + UserUpdatePasswordAction: user, + GenerateRecoveryCodesAction: user, + EnrollFactorAction: factor, + UnenrollFactorAction: 
factor, + CreateChallengeAction: factor, + VerifyFactorAction: factor, + DeleteFactorAction: factor, + UpdateFactorAction: factor, + MFACodeLoginAction: factor, + DeleteRecoveryCodesAction: recoveryCodes, +} + +// AuditLogEntry is the database model for audit log entries. +type AuditLogEntry struct { + ID uuid.UUID `json:"id" db:"id"` + Payload JSONMap `json:"payload" db:"payload"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + IPAddress string `json:"ip_address" db:"ip_address"` + + DONTUSEINSTANCEID uuid.UUID `json:"-" db:"instance_id"` +} + +func (AuditLogEntry) TableName() string { + tableName := "audit_log_entries" + return tableName +} + +func NewAuditLogEntry(r *http.Request, tx *storage.Connection, actor *User, action AuditAction, ipAddress string, traits map[string]interface{}) error { + id := uuid.Must(uuid.NewV4()) + + username := actor.GetEmail() + + if actor.GetPhone() != "" { + username = actor.GetPhone() + } + + payload := map[string]interface{}{ + "actor_id": actor.ID, + "actor_via_sso": actor.IsSSOUser, + "actor_username": username, + "action": action, + "log_type": ActionLogTypeMap[action], + } + l := AuditLogEntry{ + ID: id, + Payload: JSONMap(payload), + IPAddress: ipAddress, + } + + observability.LogEntrySetFields(r, logrus.Fields{ + "auth_event": logrus.Fields(payload), + }) + + if name, ok := actor.UserMetaData["full_name"]; ok { + l.Payload["actor_name"] = name + } + + if traits != nil { + l.Payload["traits"] = traits + } + + if err := tx.Create(&l); err != nil { + return errors.Wrap(err, "Database error creating audit log entry") + } + + return nil +} + +func FindAuditLogEntries(tx *storage.Connection, filterColumns []string, filterValue string, pageParams *Pagination) ([]*AuditLogEntry, error) { + q := tx.Q().Order("created_at desc").Where("instance_id = ?", uuid.Nil) + + if len(filterColumns) > 0 && filterValue != "" { + lf := "%" + filterValue + "%" + + builder := bytes.NewBufferString("(") + values := make([]interface{}, len(filterColumns)) + + for idx, col := range filterColumns { + builder.WriteString(fmt.Sprintf("payload->>'%s' ILIKE ?", col)) + values[idx] = lf + + if idx+1 < len(filterColumns) { + builder.WriteString(" OR ") + } + } + builder.WriteString(")") + + q = q.Where(builder.String(), values...) 
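+ // For example, filterColumns = ["actor_username", "action"] with
+ // filterValue = "bob" yields the clause
+ // "(payload->>'actor_username' ILIKE ? OR payload->>'action' ILIKE ?)"
+ // bound to the values ["%bob%", "%bob%"].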
+ } + + logs := []*AuditLogEntry{} + var err error + if pageParams != nil { + err = q.Paginate(int(pageParams.Page), int(pageParams.PerPage)).All(&logs) // #nosec G115 + pageParams.Count = uint64(q.Paginator.TotalEntriesSize) // #nosec G115 + } else { + err = q.All(&logs) + } + + return logs, err +} diff --git a/auth_v2.169.0/internal/models/challenge.go b/auth_v2.169.0/internal/models/challenge.go new file mode 100644 index 0000000..3de5b4d --- /dev/null +++ b/auth_v2.169.0/internal/models/challenge.go @@ -0,0 +1,124 @@ +package models + +import ( + "database/sql/driver" + "fmt" + + "encoding/json" + "github.com/go-webauthn/webauthn/webauthn" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/storage" + "time" +) + +type Challenge struct { + ID uuid.UUID `json:"challenge_id" db:"id"` + FactorID uuid.UUID `json:"factor_id" db:"factor_id"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + VerifiedAt *time.Time `json:"verified_at,omitempty" db:"verified_at"` + IPAddress string `json:"ip_address" db:"ip_address"` + Factor *Factor `json:"factor,omitempty" belongs_to:"factor"` + OtpCode string `json:"otp_code,omitempty" db:"otp_code"` + WebAuthnSessionData *WebAuthnSessionData `json:"web_authn_session_data,omitempty" db:"web_authn_session_data"` +} + +type WebAuthnSessionData struct { + *webauthn.SessionData +} + +func (s *WebAuthnSessionData) Scan(value interface{}) error { + if value == nil { + s.SessionData = nil + return nil + } + + // Handle byte and string as a precaution, in postgres driver, json/jsonb should be returned as []byte + var data []byte + switch v := value.(type) { + case []byte: + data = v + case string: + data = []byte(v) + default: + panic(fmt.Sprintf("unsupported type for web_authn_session_data: %T", value)) + } + + if len(data) == 0 { + s.SessionData = nil + return nil + } + if s.SessionData == nil { + s.SessionData = &webauthn.SessionData{} + } + return json.Unmarshal(data, s.SessionData) + +} + +func (s *WebAuthnSessionData) Value() (driver.Value, error) { + if s == nil || s.SessionData == nil { + return nil, nil + } + return json.Marshal(s.SessionData) +} + +func (ws *WebAuthnSessionData) ToChallenge(factorID uuid.UUID, ipAddress string) *Challenge { + id := uuid.Must(uuid.NewV4()) + return &Challenge{ + ID: id, + FactorID: factorID, + IPAddress: ipAddress, + WebAuthnSessionData: &WebAuthnSessionData{ + ws.SessionData, + }, + } + +} + +func (Challenge) TableName() string { + tableName := "mfa_challenges" + return tableName +} + +// Update the verification timestamp +func (c *Challenge) Verify(tx *storage.Connection) error { + now := time.Now() + c.VerifiedAt = &now + return tx.UpdateOnly(c, "verified_at") +} + +func (c *Challenge) HasExpired(expiryDuration float64) bool { + return time.Now().After(c.GetExpiryTime(expiryDuration)) +} + +func (c *Challenge) GetExpiryTime(expiryDuration float64) time.Time { + return c.CreatedAt.Add(time.Second * time.Duration(expiryDuration)) +} + +func (c *Challenge) SetOtpCode(otpCode string, encrypt bool, encryptionKeyID, encryptionKey string) error { + c.OtpCode = otpCode + if encrypt { + es, err := crypto.NewEncryptedString(c.ID.String(), []byte(otpCode), encryptionKeyID, encryptionKey) + if err != nil { + return err + } + + c.OtpCode = es.String() + } + return nil + +} + +func (c *Challenge) GetOtpCode(decryptionKeys map[string]string, encrypt bool, encryptionKeyID string) (string, bool, error) { + if es := crypto.ParseEncryptedString(c.OtpCode); es != nil { + 
bytes, err := es.Decrypt(c.ID.String(), decryptionKeys) + if err != nil { + return "", false, err + } + + return string(bytes), encrypt && es.ShouldReEncrypt(encryptionKeyID), nil + } + + return c.OtpCode, encrypt, nil + +} diff --git a/auth_v2.169.0/internal/models/cleanup.go b/auth_v2.169.0/internal/models/cleanup.go new file mode 100644 index 0000000..9669c8d --- /dev/null +++ b/auth_v2.169.0/internal/models/cleanup.go @@ -0,0 +1,136 @@ +package models + +import ( + "context" + "fmt" + "sync/atomic" + + "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/metric" + + "go.opentelemetry.io/otel/attribute" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" +) + +type Cleaner interface { + Clean(*storage.Connection) (int, error) +} + +type Cleanup struct { + cleanupStatements []string + + // cleanupNext holds an atomically incrementing value that determines which of + // the cleanupStatements will be run next. + cleanupNext uint32 + + // cleanupAffectedRows tracks an OpenTelemetry metric on the total number of + // cleaned up rows. + cleanupAffectedRows atomic.Int64 +} + +func NewCleanup(config *conf.GlobalConfiguration) *Cleanup { + tableUsers := User{}.TableName() + tableRefreshTokens := RefreshToken{}.TableName() + tableSessions := Session{}.TableName() + tableRelayStates := SAMLRelayState{}.TableName() + tableFlowStates := FlowState{}.TableName() + tableMFAChallenges := Challenge{}.TableName() + tableMFAFactors := Factor{}.TableName() + + c := &Cleanup{} + + // These statements intentionally use SELECT ... FOR UPDATE SKIP LOCKED + // as this makes sure that only rows that are not being used in another + // transaction are deleted. These deletes are thus very quick and + // efficient, as they don't wait on other transactions. 
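+ //
+ // Each statement touches at most a small batch per call (100 rows, or 10
+ // sessions for the cascading delete), so Clean is meant to be invoked
+ // repeatedly rather than as a one-shot full cleanup.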
+ c.cleanupStatements = append(c.cleanupStatements, + fmt.Sprintf("delete from %q where id in (select id from %q where revoked is true and updated_at < now() - interval '24 hours' limit 100 for update skip locked);", tableRefreshTokens, tableRefreshTokens), + fmt.Sprintf("update %q set revoked = true, updated_at = now() where id in (select %q.id from %q join %q on %q.session_id = %q.id where %q.not_after < now() - interval '24 hours' and %q.revoked is false limit 100 for update skip locked);", tableRefreshTokens, tableRefreshTokens, tableRefreshTokens, tableSessions, tableRefreshTokens, tableSessions, tableSessions, tableRefreshTokens), + // sessions are deleted after 72 hours to allow refresh tokens + // to be deleted piecemeal; 10 at once so that cascades don't + // overwork the database + fmt.Sprintf("delete from %q where id in (select id from %q where not_after < now() - interval '72 hours' limit 10 for update skip locked);", tableSessions, tableSessions), + fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableRelayStates, tableRelayStates), + fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableFlowStates, tableFlowStates), + fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' limit 100 for update skip locked);", tableMFAChallenges, tableMFAChallenges), + fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '24 hours' and status = 'unverified' limit 100 for update skip locked);", tableMFAFactors, tableMFAFactors), + ) + + if config.External.AnonymousUsers.Enabled { + // delete anonymous users older than 30 days + c.cleanupStatements = append(c.cleanupStatements, + fmt.Sprintf("delete from %q where id in (select id from %q where created_at < now() - interval '30 days' and is_anonymous is true limit 100 for update skip locked);", tableUsers, tableUsers), + ) + } + + if config.Sessions.Timebox != nil { + timeboxSeconds := int((*config.Sessions.Timebox).Seconds()) + + c.cleanupStatements = append(c.cleanupStatements, fmt.Sprintf("delete from %q where id in (select id from %q where created_at + interval '%d seconds' < now() - interval '24 hours' limit 100 for update skip locked);", tableSessions, tableSessions, timeboxSeconds)) + } + + if config.Sessions.InactivityTimeout != nil { + inactivitySeconds := int((*config.Sessions.InactivityTimeout).Seconds()) + + // delete sessions with a refreshed_at column + c.cleanupStatements = append(c.cleanupStatements, fmt.Sprintf("delete from %q where id in (select id from %q where refreshed_at is not null and refreshed_at + interval '%d seconds' < now() - interval '24 hours' limit 100 for update skip locked);", tableSessions, tableSessions, inactivitySeconds)) + + // delete sessions without a refreshed_at column by looking for + // unrevoked refresh_tokens + c.cleanupStatements = append(c.cleanupStatements, fmt.Sprintf("delete from %q where id in (select %q.id as id from %q, %q where %q.session_id = %q.id and %q.refreshed_at is null and %q.revoked is false and %q.updated_at + interval '%d seconds' < now() - interval '24 hours' limit 100 for update skip locked)", tableSessions, tableSessions, tableSessions, tableRefreshTokens, tableRefreshTokens, tableSessions, tableSessions, tableRefreshTokens, tableRefreshTokens, inactivitySeconds)) + } + + meter := otel.Meter("gotrue") 
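+ // The observable counter below reports the cumulative number of
+ // cleaned-up rows; its callback is invoked on every metrics collection.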
+ + _, err := meter.Int64ObservableCounter( + "gotrue_cleanup_affected_rows", + metric.WithDescription("Number of affected rows from cleaning up stale entities"), + metric.WithInt64Callback(func(_ context.Context, o metric.Int64Observer) error { + o.Observe(c.cleanupAffectedRows.Load()) + return nil + }), + ) + + if err != nil { + logrus.WithError(err).Error("unable to get gotrue.gotrue_cleanup_rows counter metric") + } + + return c +} + +// Cleanup removes stale entities in the database. You can call it on each +// request or as a periodic background job. It does quick lockless updates or +// deletes, has an execution timeout and acquire timeout so that cleanups do +// not affect performance of other database jobs. Note that calling this does +// not clean up the whole database, but does a small piecemeal clean up each +// time when called. +func (c *Cleanup) Clean(db *storage.Connection) (int, error) { + ctx, span := observability.Tracer("gotrue").Start(db.Context(), "database-cleanup") + defer span.End() + + affectedRows := 0 + defer span.SetAttributes(attribute.Int64("gotrue.cleanup.affected_rows", int64(affectedRows))) + + if err := db.WithContext(ctx).Transaction(func(tx *storage.Connection) error { + nextIndex := atomic.AddUint32(&c.cleanupNext, 1) % uint32(len(c.cleanupStatements)) // #nosec G115 + statement := c.cleanupStatements[nextIndex] + + count, terr := tx.RawQuery(statement).ExecWithCount() + if terr != nil { + return terr + } + + affectedRows += count + + return nil + }); err != nil { + return affectedRows, err + } + c.cleanupAffectedRows.Add(int64(affectedRows)) + + return affectedRows, nil +} diff --git a/auth_v2.169.0/internal/models/cleanup_test.go b/auth_v2.169.0/internal/models/cleanup_test.go new file mode 100644 index 0000000..618fbba --- /dev/null +++ b/auth_v2.169.0/internal/models/cleanup_test.go @@ -0,0 +1,31 @@ +package models + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage/test" +) + +func TestCleanup(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + timebox := 10 * time.Second + inactivityTimeout := 5 * time.Second + globalConfig.Sessions.Timebox = &timebox + globalConfig.Sessions.InactivityTimeout = &inactivityTimeout + globalConfig.External.AnonymousUsers.Enabled = true + + cleanup := NewCleanup(globalConfig) + + for i := 0; i < 100; i += 1 { + _, err := cleanup.Clean(conn) + require.NoError(t, err) + } +} diff --git a/auth_v2.169.0/internal/models/connection.go b/auth_v2.169.0/internal/models/connection.go new file mode 100644 index 0000000..80acccc --- /dev/null +++ b/auth_v2.169.0/internal/models/connection.go @@ -0,0 +1,62 @@ +package models + +import ( + "github.com/gobuffalo/pop/v6" + "github.com/supabase/auth/internal/storage" +) + +type Pagination struct { + Page uint64 + PerPage uint64 + Count uint64 +} + +func (p *Pagination) Offset() uint64 { + return (p.Page - 1) * p.PerPage +} + +type SortDirection string + +const Ascending SortDirection = "ASC" +const Descending SortDirection = "DESC" +const CreatedAt = "created_at" + +type SortParams struct { + Fields []SortField +} + +type SortField struct { + Name string + Dir SortDirection +} + +// TruncateAll deletes all data from the database, as managed by GoTrue. Not +// intended for use outside of tests. 
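+// The deletes for all GoTrue-managed tables run inside a single transaction.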
+func TruncateAll(conn *storage.Connection) error { + return conn.Transaction(func(tx *storage.Connection) error { + tables := []string{ + (&pop.Model{Value: User{}}).TableName(), + (&pop.Model{Value: Identity{}}).TableName(), + (&pop.Model{Value: RefreshToken{}}).TableName(), + (&pop.Model{Value: AuditLogEntry{}}).TableName(), + (&pop.Model{Value: Session{}}).TableName(), + (&pop.Model{Value: Factor{}}).TableName(), + (&pop.Model{Value: Challenge{}}).TableName(), + (&pop.Model{Value: AMRClaim{}}).TableName(), + (&pop.Model{Value: SSOProvider{}}).TableName(), + (&pop.Model{Value: SSODomain{}}).TableName(), + (&pop.Model{Value: SAMLProvider{}}).TableName(), + (&pop.Model{Value: SAMLRelayState{}}).TableName(), + (&pop.Model{Value: FlowState{}}).TableName(), + (&pop.Model{Value: OneTimeToken{}}).TableName(), + } + + for _, tableName := range tables { + if err := tx.RawQuery("DELETE FROM " + tableName + " CASCADE").Exec(); err != nil { + return err + } + } + + return nil + }) +} diff --git a/auth_v2.169.0/internal/models/db_test.go b/auth_v2.169.0/internal/models/db_test.go new file mode 100644 index 0000000..c3d6ab2 --- /dev/null +++ b/auth_v2.169.0/internal/models/db_test.go @@ -0,0 +1,24 @@ +package models + +import ( + "testing" + + "github.com/gobuffalo/pop/v6" + "github.com/stretchr/testify/assert" +) + +func TestTableNameNamespacing(t *testing.T) { + cases := []struct { + expected string + value interface{} + }{ + {expected: "audit_log_entries", value: []*AuditLogEntry{}}, + {expected: "refresh_tokens", value: []*RefreshToken{}}, + {expected: "users", value: []*User{}}, + } + + for _, tc := range cases { + m := &pop.Model{Value: tc.value} + assert.Equal(t, tc.expected, m.TableName()) + } +} diff --git a/auth_v2.169.0/internal/models/errors.go b/auth_v2.169.0/internal/models/errors.go new file mode 100644 index 0000000..96f8319 --- /dev/null +++ b/auth_v2.169.0/internal/models/errors.go @@ -0,0 +1,125 @@ +package models + +// IsNotFoundError returns whether an error represents a "not found" error. +func IsNotFoundError(err error) bool { + switch err.(type) { + case UserNotFoundError, *UserNotFoundError: + return true + case SessionNotFoundError, *SessionNotFoundError: + return true + case ConfirmationTokenNotFoundError, *ConfirmationTokenNotFoundError: + return true + case ConfirmationOrRecoveryTokenNotFoundError, *ConfirmationOrRecoveryTokenNotFoundError: + return true + case RefreshTokenNotFoundError, *RefreshTokenNotFoundError: + return true + case IdentityNotFoundError, *IdentityNotFoundError: + return true + case ChallengeNotFoundError, *ChallengeNotFoundError: + return true + case FactorNotFoundError, *FactorNotFoundError: + return true + case SSOProviderNotFoundError, *SSOProviderNotFoundError: + return true + case SAMLRelayStateNotFoundError, *SAMLRelayStateNotFoundError: + return true + case FlowStateNotFoundError, *FlowStateNotFoundError: + return true + case OneTimeTokenNotFoundError, *OneTimeTokenNotFoundError: + return true + } + return false +} + +type SessionNotFoundError struct{} + +func (e SessionNotFoundError) Error() string { + return "Session not found" +} + +// UserNotFoundError represents when a user is not found. +type UserNotFoundError struct{} + +func (e UserNotFoundError) Error() string { + return "User not found" +} + +// IdentityNotFoundError represents when an identity is not found. 
+type IdentityNotFoundError struct{} + +func (e IdentityNotFoundError) Error() string { + return "Identity not found" +} + +// ConfirmationOrRecoveryTokenNotFoundError represents when a confirmation or recovery token is not found. +type ConfirmationOrRecoveryTokenNotFoundError struct{} + +func (e ConfirmationOrRecoveryTokenNotFoundError) Error() string { + return "Confirmation or Recovery Token not found" +} + +// ConfirmationTokenNotFoundError represents when a confirmation token is not found. +type ConfirmationTokenNotFoundError struct{} + +func (e ConfirmationTokenNotFoundError) Error() string { + return "Confirmation Token not found" +} + +// RefreshTokenNotFoundError represents when a refresh token is not found. +type RefreshTokenNotFoundError struct{} + +func (e RefreshTokenNotFoundError) Error() string { + return "Refresh Token not found" +} + +// FactorNotFoundError represents when a user is not found. +type FactorNotFoundError struct{} + +func (e FactorNotFoundError) Error() string { + return "Factor not found" +} + +// ChallengeNotFoundError represents when a user is not found. +type ChallengeNotFoundError struct{} + +func (e ChallengeNotFoundError) Error() string { + return "Challenge not found" +} + +// SSOProviderNotFoundError represents an error when a SSO Provider can't be +// found. +type SSOProviderNotFoundError struct{} + +func (e SSOProviderNotFoundError) Error() string { + return "SSO Identity Provider not found" +} + +// SAMLRelayStateNotFoundError represents an error when a SAML relay state +// can't be found. +type SAMLRelayStateNotFoundError struct{} + +func (e SAMLRelayStateNotFoundError) Error() string { + return "SAML RelayState not found" +} + +// FlowStateNotFoundError represents an error when an FlowState can't be +// found. 
+type FlowStateNotFoundError struct{} + +func (e FlowStateNotFoundError) Error() string { + return "Flow State not found" +} + +func IsUniqueConstraintViolatedError(err error) bool { + switch err.(type) { + case UserEmailUniqueConflictError, *UserEmailUniqueConflictError: + return true + } + return false +} + +type UserEmailUniqueConflictError struct{} + +func (e UserEmailUniqueConflictError) Error() string { + return "User email unique constraint violated" +} diff --git a/auth_v2.169.0/internal/models/factor.go b/auth_v2.169.0/internal/models/factor.go new file mode 100644 index 0000000..a88874d --- /dev/null +++ b/auth_v2.169.0/internal/models/factor.go @@ -0,0 +1,398 @@ +package models + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/go-webauthn/webauthn/webauthn" + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/storage" +) + +type FactorState int + +const ( + FactorStateUnverified FactorState = iota + FactorStateVerified +) + +func (factorState FactorState) String() string { + switch factorState { + case FactorStateUnverified: + return "unverified" + case FactorStateVerified: + return "verified" + } + return "" +} + +const TOTP = "totp" +const Phone = "phone" +const WebAuthn = "webauthn" + +type AuthenticationMethod int + +const ( + OAuth AuthenticationMethod = iota + PasswordGrant + OTP + TOTPSignIn + MFAPhone + MFAWebAuthn + SSOSAML + Recovery + Invite + MagicLink + EmailSignup + EmailChange + TokenRefresh + Anonymous +) + +func (authMethod AuthenticationMethod) String() string { + switch authMethod { + case OAuth: + return "oauth" + case PasswordGrant: + return "password" + case OTP: + return "otp" + case TOTPSignIn: + return "totp" + case Recovery: + return "recovery" + case Invite: + return "invite" + case SSOSAML: + return "sso/saml" + case MagicLink: + return "magiclink" + case EmailSignup: + return "email/signup" + case EmailChange: + return "email_change" + case TokenRefresh: + return "token_refresh" + case Anonymous: + return "anonymous" + case MFAPhone: + return "mfa/phone" + case MFAWebAuthn: + return "mfa/webauthn" + } + return "" +} + +func ParseAuthenticationMethod(authMethod string) (AuthenticationMethod, error) { + if strings.HasSuffix(authMethod, "signup") { + authMethod = "email/signup" + } + switch authMethod { + case "oauth": + return OAuth, nil + case "password": + return PasswordGrant, nil + case "otp": + return OTP, nil + case "totp": + return TOTPSignIn, nil + case "recovery": + return Recovery, nil + case "invite": + return Invite, nil + case "sso/saml": + return SSOSAML, nil + case "magiclink": + return MagicLink, nil + case "email/signup": + return EmailSignup, nil + case "email_change": + return EmailChange, nil + case "token_refresh": + return TokenRefresh, nil + case "mfa/sms": + return MFAPhone, nil + case "mfa/webauthn": + return MFAWebAuthn, nil + } + return 0, fmt.Errorf("unsupported authentication method %q", authMethod) +} + +type Factor struct { + ID uuid.UUID `json:"id" db:"id"` + // TODO: Consider removing this nested user field. We don't use it. 
+ User User `json:"-" belongs_to:"user"` + UserID uuid.UUID `json:"-" db:"user_id"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + Status string `json:"status" db:"status"` + FriendlyName string `json:"friendly_name,omitempty" db:"friendly_name"` + Secret string `json:"-" db:"secret"` + FactorType string `json:"factor_type" db:"factor_type"` + Challenge []Challenge `json:"-" has_many:"challenges"` + Phone storage.NullString `json:"phone" db:"phone"` + LastChallengedAt *time.Time `json:"last_challenged_at" db:"last_challenged_at"` + WebAuthnCredential *WebAuthnCredential `json:"-" db:"web_authn_credential"` + WebAuthnAAGUID *uuid.UUID `json:"web_authn_aaguid,omitempty" db:"web_authn_aaguid"` +} + +type WebAuthnCredential struct { + webauthn.Credential +} + +func (wc *WebAuthnCredential) Value() (driver.Value, error) { + if wc == nil { + return nil, nil + } + return json.Marshal(wc) +} + +func (wc *WebAuthnCredential) Scan(value interface{}) error { + if value == nil { + wc.Credential = webauthn.Credential{} + return nil + } + // Handle byte and string as a precaution, in postgres driver, json/jsonb should be returned as []byte + var data []byte + switch v := value.(type) { + case []byte: + data = v + case string: + data = []byte(v) + default: + return fmt.Errorf("unsupported type for web_authn_credential: %T", value) + } + if len(data) == 0 { + wc.Credential = webauthn.Credential{} + return nil + } + return json.Unmarshal(data, &wc.Credential) +} + +func (Factor) TableName() string { + tableName := "mfa_factors" + return tableName +} + +func NewFactor(user *User, friendlyName string, factorType string, state FactorState) *Factor { + id := uuid.Must(uuid.NewV4()) + + factor := &Factor{ + ID: id, + UserID: user.ID, + Status: state.String(), + FriendlyName: friendlyName, + FactorType: factorType, + } + return factor +} + +func NewTOTPFactor(user *User, friendlyName string) *Factor { + return NewFactor(user, friendlyName, TOTP, FactorStateUnverified) +} + +func NewPhoneFactor(user *User, phone, friendlyName string) *Factor { + factor := NewFactor(user, friendlyName, Phone, FactorStateUnverified) + factor.Phone = storage.NullString(phone) + return factor +} + +func NewWebAuthnFactor(user *User, friendlyName string) *Factor { + factor := NewFactor(user, friendlyName, WebAuthn, FactorStateUnverified) + return factor +} + +func (f *Factor) SetSecret(secret string, encrypt bool, encryptionKeyID, encryptionKey string) error { + f.Secret = secret + if encrypt { + es, err := crypto.NewEncryptedString(f.ID.String(), []byte(secret), encryptionKeyID, encryptionKey) + if err != nil { + return err + } + + f.Secret = es.String() + } + + return nil +} + +func (f *Factor) GetSecret(decryptionKeys map[string]string, encrypt bool, encryptionKeyID string) (string, bool, error) { + if es := crypto.ParseEncryptedString(f.Secret); es != nil { + bytes, err := es.Decrypt(f.ID.String(), decryptionKeys) + if err != nil { + return "", false, err + } + + return string(bytes), encrypt && es.ShouldReEncrypt(encryptionKeyID), nil + } + + return f.Secret, encrypt, nil +} + +func (f *Factor) SaveWebAuthnCredential(tx *storage.Connection, credential *webauthn.Credential) error { + f.WebAuthnCredential = &WebAuthnCredential{ + Credential: *credential, + } + + if len(credential.Authenticator.AAGUID) > 0 { + aaguidUUID, err := uuid.FromBytes(credential.Authenticator.AAGUID) + if err != nil { + return fmt.Errorf("WebAuthn authenticator AAGUID is not UUID: %w", 
err) + } + f.WebAuthnAAGUID = &aaguidUUID + } else { + f.WebAuthnAAGUID = nil + } + + return tx.UpdateOnly(f, "web_authn_credential", "web_authn_aaguid", "updated_at") +} + +func FindFactorByFactorID(conn *storage.Connection, factorID uuid.UUID) (*Factor, error) { + var factor Factor + err := conn.Find(&factor, factorID) + if err != nil && errors.Cause(err) == sql.ErrNoRows { + return nil, FactorNotFoundError{} + } else if err != nil { + return nil, err + } + return &factor, nil +} + +func DeleteUnverifiedFactors(tx *storage.Connection, user *User, factorType string) error { + if err := tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Factor{}}).TableName()+" WHERE user_id = ? and status = ? and factor_type = ?", user.ID, FactorStateUnverified.String(), factorType).Exec(); err != nil { + return err + } + + return nil +} + +func (f *Factor) CreateChallenge(ipAddress string) *Challenge { + id := uuid.Must(uuid.NewV4()) + challenge := &Challenge{ + ID: id, + FactorID: f.ID, + IPAddress: ipAddress, + } + + return challenge +} +func (f *Factor) WriteChallengeToDatabase(tx *storage.Connection, challenge *Challenge) error { + if challenge.FactorID != f.ID { + return errors.New("Can only write challenges that you own") + } + now := time.Now() + f.LastChallengedAt = &now + if terr := tx.Create(challenge); terr != nil { + return terr + } + if err := tx.UpdateOnly(f, "last_challenged_at"); err != nil { + return err + } + return nil +} + +func (f *Factor) CreatePhoneChallenge(ipAddress string, otpCode string, encrypt bool, encryptionKeyID, encryptionKey string) (*Challenge, error) { + phoneChallenge := f.CreateChallenge(ipAddress) + if err := phoneChallenge.SetOtpCode(otpCode, encrypt, encryptionKeyID, encryptionKey); err != nil { + return nil, err + } + return phoneChallenge, nil +} + +// UpdateFriendlyName changes the friendly name +func (f *Factor) UpdateFriendlyName(tx *storage.Connection, friendlyName string) error { + f.FriendlyName = friendlyName + return tx.UpdateOnly(f, "friendly_name", "updated_at") +} + +func (f *Factor) UpdatePhone(tx *storage.Connection, phone string) error { + f.Phone = storage.NullString(phone) + return tx.UpdateOnly(f, "phone", "updated_at") +} + +// UpdateStatus modifies the factor status +func (f *Factor) UpdateStatus(tx *storage.Connection, state FactorState) error { + f.Status = state.String() + return tx.UpdateOnly(f, "status", "updated_at") +} + +func (f *Factor) DowngradeSessionsToAAL1(tx *storage.Connection) error { + sessions, err := FindSessionsByFactorID(tx, f.ID) + if err != nil { + return err + } + for _, session := range sessions { + if err := tx.RawQuery("DELETE FROM "+(&pop.Model{Value: AMRClaim{}}).TableName()+" WHERE session_id = ? AND authentication_method = ?", session.ID, f.FactorType).Exec(); err != nil { + return err + } + } + return updateFactorAssociatedSessions(tx, f.UserID, f.ID, AAL1.String()) +} + +func (f *Factor) IsVerified() bool { + return f.Status == FactorStateVerified.String() +} + +func (f *Factor) IsUnverified() bool { + return f.Status == FactorStateUnverified.String() +} + +func (f *Factor) IsPhoneFactor() bool { + return f.FactorType == Phone +} + +func (f *Factor) FindChallengeByID(conn *storage.Connection, challengeID uuid.UUID) (*Challenge, error) { + var challenge Challenge + err := conn.Q().Where("id = ? 
and factor_id = ?", challengeID, f.ID).First(&challenge) + if err != nil && errors.Cause(err) == sql.ErrNoRows { + return nil, ChallengeNotFoundError{} + } else if err != nil { + return nil, err + } + return &challenge, nil +} + +func DeleteFactorsByUserId(tx *storage.Connection, userId uuid.UUID) error { + if err := tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Factor{}}).TableName()+" WHERE user_id = ?", userId).Exec(); err != nil { + return err + } + return nil +} + +func DeleteExpiredFactors(tx *storage.Connection, validityDuration time.Duration) error { + totalSeconds := int64(validityDuration / time.Second) + validityInterval := fmt.Sprintf("interval '%d seconds'", totalSeconds) + + factorTable := (&pop.Model{Value: Factor{}}).TableName() + challengeTable := (&pop.Model{Value: Challenge{}}).TableName() + + query := fmt.Sprintf(`delete from %q where status != 'verified' and not exists (select * from %q where %q.id = %q.factor_id ) and created_at + %s < current_timestamp;`, factorTable, challengeTable, factorTable, challengeTable, validityInterval) + if err := tx.RawQuery(query).Exec(); err != nil { + return err + } + return nil +} + +func (f *Factor) FindLatestUnexpiredChallenge(tx *storage.Connection, expiryDuration float64) (*Challenge, error) { + now := time.Now() + var challenge Challenge + expirationTime := now.Add(time.Duration(expiryDuration) * time.Second) + + err := tx.Where("sent_at > ? and factor_id = ?", expirationTime, f.ID). + Order("sent_at desc"). + First(&challenge) + + if err != nil && errors.Cause(err) == sql.ErrNoRows { + return nil, ChallengeNotFoundError{} + } else if err != nil { + return nil, err + } + return &challenge, nil +} diff --git a/auth_v2.169.0/internal/models/factor_test.go b/auth_v2.169.0/internal/models/factor_test.go new file mode 100644 index 0000000..614cff2 --- /dev/null +++ b/auth_v2.169.0/internal/models/factor_test.go @@ -0,0 +1,74 @@ +package models + +import ( + "encoding/json" + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +type FactorTestSuite struct { + suite.Suite + db *storage.Connection + TestFactor *Factor +} + +func TestFactor(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + ts := &FactorTestSuite{ + db: conn, + } + defer ts.db.Close() + suite.Run(t, ts) +} + +func (ts *FactorTestSuite) SetupTest() { + TruncateAll(ts.db) + user, err := NewUser("", "agenericemail@gmail.com", "secret", "test", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(user)) + + factor := NewTOTPFactor(user, "asimplename") + require.NoError(ts.T(), factor.SetSecret("topsecret", false, "", "")) + require.NoError(ts.T(), ts.db.Create(factor)) + ts.TestFactor = factor +} + +func (ts *FactorTestSuite) TestFindFactorByFactorID() { + n, err := FindFactorByFactorID(ts.db, ts.TestFactor.ID) + require.NoError(ts.T(), err) + require.Equal(ts.T(), ts.TestFactor.ID, n.ID) + + _, err = FindFactorByFactorID(ts.db, uuid.Nil) + require.EqualError(ts.T(), err, FactorNotFoundError{}.Error()) +} + +func (ts *FactorTestSuite) TestUpdateStatus() { + newFactorStatus := FactorStateVerified + require.NoError(ts.T(), ts.TestFactor.UpdateStatus(ts.db, newFactorStatus)) + require.Equal(ts.T(), newFactorStatus.String(), 
ts.TestFactor.Status) +} + +func (ts *FactorTestSuite) TestUpdateFriendlyName() { + newName := "newfactorname" + require.NoError(ts.T(), ts.TestFactor.UpdateFriendlyName(ts.db, newName)) + require.Equal(ts.T(), newName, ts.TestFactor.FriendlyName) +} + +func (ts *FactorTestSuite) TestEncodedFactorDoesNotLeakSecret() { + encodedFactor, err := json.Marshal(ts.TestFactor) + require.NoError(ts.T(), err) + + decodedFactor := Factor{} + json.Unmarshal(encodedFactor, &decodedFactor) + require.Equal(ts.T(), decodedFactor.Secret, "") +} diff --git a/auth_v2.169.0/internal/models/flow_state.go b/auth_v2.169.0/internal/models/flow_state.go new file mode 100644 index 0000000..9a770d8 --- /dev/null +++ b/auth_v2.169.0/internal/models/flow_state.go @@ -0,0 +1,169 @@ +package models + +import ( + "crypto/sha256" + "crypto/subtle" + "database/sql" + "encoding/base64" + "fmt" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" + + "github.com/gofrs/uuid" +) + +const InvalidCodeChallengeError = "code challenge does not match previously saved code verifier" +const InvalidCodeMethodError = "code challenge method not supported" + +type FlowState struct { + ID uuid.UUID `json:"id" db:"id"` + UserID *uuid.UUID `json:"user_id,omitempty" db:"user_id"` + AuthCode string `json:"auth_code" db:"auth_code"` + AuthenticationMethod string `json:"authentication_method" db:"authentication_method"` + CodeChallenge string `json:"code_challenge" db:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method" db:"code_challenge_method"` + ProviderType string `json:"provider_type" db:"provider_type"` + ProviderAccessToken string `json:"provider_access_token" db:"provider_access_token"` + ProviderRefreshToken string `json:"provider_refresh_token" db:"provider_refresh_token"` + AuthCodeIssuedAt *time.Time `json:"auth_code_issued_at" db:"auth_code_issued_at"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +type CodeChallengeMethod int + +const ( + SHA256 CodeChallengeMethod = iota + Plain +) + +func (codeChallengeMethod CodeChallengeMethod) String() string { + switch codeChallengeMethod { + case SHA256: + return "s256" + case Plain: + return "plain" + } + return "" +} + +func ParseCodeChallengeMethod(codeChallengeMethod string) (CodeChallengeMethod, error) { + switch strings.ToLower(codeChallengeMethod) { + case "s256": + return SHA256, nil + case "plain": + return Plain, nil + } + return 0, fmt.Errorf("unsupported code_challenge method %q", codeChallengeMethod) +} + +type FlowType int + +const ( + PKCEFlow FlowType = iota + ImplicitFlow +) + +func (flowType FlowType) String() string { + switch flowType { + case PKCEFlow: + return "pkce" + case ImplicitFlow: + return "implicit" + } + return "" +} + +func (FlowState) TableName() string { + tableName := "flow_state" + return tableName +} + +func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) *FlowState { + id := uuid.Must(uuid.NewV4()) + authCode := uuid.Must(uuid.NewV4()) + flowState := &FlowState{ + ID: id, + ProviderType: providerType, + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod.String(), + AuthCode: authCode.String(), + AuthenticationMethod: authenticationMethod.String(), + UserID: userID, + } + return flowState +} + +func FindFlowStateByAuthCode(tx *storage.Connection, authCode string) (*FlowState, error) { + obj 
:= &FlowState{} + if err := tx.Eager().Q().Where("auth_code = ?", authCode).First(obj); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, FlowStateNotFoundError{} + } + return nil, errors.Wrap(err, "error finding flow state") + } + + return obj, nil +} + +func FindFlowStateByID(tx *storage.Connection, id string) (*FlowState, error) { + obj := &FlowState{} + if err := tx.Eager().Q().Where("id = ?", id).First(obj); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, FlowStateNotFoundError{} + } + return nil, errors.Wrap(err, "error finding flow state") + } + + return obj, nil +} + +func FindFlowStateByUserID(tx *storage.Connection, id string, authenticationMethod AuthenticationMethod) (*FlowState, error) { + obj := &FlowState{} + if err := tx.Eager().Q().Where("user_id = ? and authentication_method = ?", id, authenticationMethod).Last(obj); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, FlowStateNotFoundError{} + } + return nil, errors.Wrap(err, "error finding flow state") + } + + return obj, nil +} + +func (f *FlowState) VerifyPKCE(codeVerifier string) error { + switch f.CodeChallengeMethod { + case SHA256.String(): + hashedCodeVerifier := sha256.Sum256([]byte(codeVerifier)) + encodedCodeVerifier := base64.RawURLEncoding.EncodeToString(hashedCodeVerifier[:]) + if subtle.ConstantTimeCompare([]byte(f.CodeChallenge), []byte(encodedCodeVerifier)) != 1 { + return errors.New(InvalidCodeChallengeError) + } + case Plain.String(): + if subtle.ConstantTimeCompare([]byte(f.CodeChallenge), []byte(codeVerifier)) != 1 { + return errors.New(InvalidCodeChallengeError) + } + default: + return errors.New(InvalidCodeMethodError) + + } + return nil +} + +func (f *FlowState) IsExpired(expiryDuration time.Duration) bool { + if f.AuthCodeIssuedAt != nil && f.AuthenticationMethod == MagicLink.String() { + return time.Now().After(f.AuthCodeIssuedAt.Add(expiryDuration)) + } + return time.Now().After(f.CreatedAt.Add(expiryDuration)) +} + +func (f *FlowState) RecordAuthCodeIssuedAtTime(tx *storage.Connection) error { + issueTime := time.Now() + f.AuthCodeIssuedAt = &issueTime + if err := tx.Update(f); err != nil { + return err + } + return nil +} diff --git a/auth_v2.169.0/internal/models/identity.go b/auth_v2.169.0/internal/models/identity.go new file mode 100644 index 0000000..c647cbc --- /dev/null +++ b/auth_v2.169.0/internal/models/identity.go @@ -0,0 +1,142 @@ +package models + +import ( + "database/sql" + "strings" + "time" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" +) + +type Identity struct { + // returned as identity_id in JSON for backward compatibility with the interface exposed by the client library + // see https://github.com/supabase/gotrue-js/blob/c9296bbc27a2f036af55c1f33fca5930704bd021/src/lib/types.ts#L230-L240 + ID uuid.UUID `json:"identity_id" db:"id"` + // returned as id in JSON for backward compatibility with the interface exposed by the client library + // see https://github.com/supabase/gotrue-js/blob/c9296bbc27a2f036af55c1f33fca5930704bd021/src/lib/types.ts#L230-L240 + ProviderID string `json:"id" db:"provider_id"` + UserID uuid.UUID `json:"user_id" db:"user_id"` + IdentityData JSONMap `json:"identity_data,omitempty" db:"identity_data"` + Provider string `json:"provider" db:"provider"` + LastSignInAt *time.Time `json:"last_sign_in_at,omitempty" db:"last_sign_in_at"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time 
`json:"updated_at" db:"updated_at"` + Email storage.NullString `json:"email,omitempty" db:"email" rw:"r"` +} + +func (Identity) TableName() string { + tableName := "identities" + return tableName +} + +// GetEmail returns the user's email as a string +func (i *Identity) GetEmail() string { + return string(i.Email) +} + +// NewIdentity returns an identity associated to the user's id. +func NewIdentity(user *User, provider string, identityData map[string]interface{}) (*Identity, error) { + providerId, ok := identityData["sub"] + if !ok { + return nil, errors.New("error missing provider id") + } + now := time.Now() + + identity := &Identity{ + ProviderID: providerId.(string), + UserID: user.ID, + IdentityData: identityData, + Provider: provider, + LastSignInAt: &now, + } + if email, ok := identityData["email"]; ok { + identity.Email = storage.NullString(email.(string)) + } + + return identity, nil +} + +func (i *Identity) BeforeCreate(tx *pop.Connection) error { + return i.BeforeUpdate(tx) +} + +func (i *Identity) BeforeUpdate(tx *pop.Connection) error { + if _, ok := i.IdentityData["email"]; ok { + i.IdentityData["email"] = strings.ToLower(i.IdentityData["email"].(string)) + } + return nil +} + +func (i *Identity) IsForSSOProvider() bool { + return strings.HasPrefix(i.Provider, "sso:") +} + +// FindIdentityById searches for an identity with the matching id and provider given. +func FindIdentityByIdAndProvider(tx *storage.Connection, providerId, provider string) (*Identity, error) { + identity := &Identity{} + if err := tx.Q().Where("provider_id = ? AND provider = ?", providerId, provider).First(identity); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, IdentityNotFoundError{} + } + return nil, errors.Wrap(err, "error finding identity") + } + return identity, nil +} + +// FindIdentitiesByUserID returns all identities associated to a user ID. +func FindIdentitiesByUserID(tx *storage.Connection, userID uuid.UUID) ([]*Identity, error) { + identities := []*Identity{} + if err := tx.Q().Where("user_id = ?", userID).All(&identities); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return identities, nil + } + return nil, errors.Wrap(err, "error finding identities") + } + return identities, nil +} + +// FindProvidersByUser returns all providers associated to a user +func FindProvidersByUser(tx *storage.Connection, user *User) ([]string, error) { + identities := []Identity{} + providerExists := map[string]bool{} + providers := make([]string, 0) + if err := tx.Q().Select("provider").Where("user_id = ?", user.ID).All(&identities); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return providers, nil + } + return nil, errors.Wrap(err, "error finding providers") + } + for _, identity := range identities { + if _, ok := providerExists[identity.Provider]; !ok { + providers = append(providers, identity.Provider) + providerExists[identity.Provider] = true + } + } + return providers, nil +} + +// UpdateIdentityData sets all identity_data from a map of updates, +// ensuring that it doesn't override attributes that are not +// in the provided map. +func (i *Identity) UpdateIdentityData(tx *storage.Connection, updates map[string]interface{}) error { + if i.IdentityData == nil { + i.IdentityData = updates + } else { + for key, value := range updates { + if value != nil { + i.IdentityData[key] = value + } else { + delete(i.IdentityData, key) + } + } + } + // pop doesn't support updates on tables with composite primary keys so we use a raw query here. 
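+	// Hypothetical caller sketch (assumes an open *storage.Connection named tx
+	// and a previously loaded *Identity): keys mapped to a non-nil value are
+	// added or overwritten, keys mapped to nil are removed, and keys missing
+	// from the map keep their stored value.
+	//
+	//	err := identity.UpdateIdentityData(tx, map[string]interface{}{
+	//		"name":  "New Name", // upserted into identity_data
+	//		"email": nil,        // deleted from identity_data
+	//	})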
+ return tx.RawQuery( + "update "+(&pop.Model{Value: Identity{}}).TableName()+" set identity_data = ? where id = ?", + i.IdentityData, + i.ID, + ).Exec() +} diff --git a/auth_v2.169.0/internal/models/identity_test.go b/auth_v2.169.0/internal/models/identity_test.go new file mode 100644 index 0000000..d27d17b --- /dev/null +++ b/auth_v2.169.0/internal/models/identity_test.go @@ -0,0 +1,117 @@ +package models + +import ( + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +type IdentityTestSuite struct { + suite.Suite + db *storage.Connection +} + +func (ts *IdentityTestSuite) SetupTest() { + TruncateAll(ts.db) +} + +func TestIdentity(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + ts := &IdentityTestSuite{ + db: conn, + } + defer ts.db.Close() + + suite.Run(t, ts) +} + +func (ts *IdentityTestSuite) TestNewIdentity() { + u := ts.createUserWithEmail("test@supabase.io") + ts.Run("Test create identity with no provider id", func() { + identityData := map[string]interface{}{} + _, err := NewIdentity(u, "email", identityData) + require.Error(ts.T(), err, "Error missing provider id") + }) + + ts.Run("Test create identity successfully", func() { + identityData := map[string]interface{}{"sub": uuid.Nil.String()} + identity, err := NewIdentity(u, "email", identityData) + require.NoError(ts.T(), err) + require.Equal(ts.T(), u.ID, identity.UserID) + }) +} + +func (ts *IdentityTestSuite) TestFindUserIdentities() { + u := ts.createUserWithIdentity("test@supabase.io") + identities, err := FindIdentitiesByUserID(ts.db, u.ID) + require.NoError(ts.T(), err) + + require.Len(ts.T(), identities, 1) + +} + +func (ts *IdentityTestSuite) TestUpdateIdentityData() { + u := ts.createUserWithIdentity("test@supabase.io") + + identities, err := FindIdentitiesByUserID(ts.db, u.ID) + require.NoError(ts.T(), err) + + updates := map[string]interface{}{ + "sub": nil, + "name": nil, + "email": nil, + } + for _, identity := range identities { + err := identity.UpdateIdentityData(ts.db, updates) + require.NoError(ts.T(), err) + } + + updatedIdentities, err := FindIdentitiesByUserID(ts.db, u.ID) + require.NoError(ts.T(), err) + for _, identity := range updatedIdentities { + require.Empty(ts.T(), identity.IdentityData) + } +} + +func (ts *IdentityTestSuite) createUserWithEmail(email string) *User { + user, err := NewUser("", email, "secret", "test", nil) + require.NoError(ts.T(), err) + + err = ts.db.Create(user) + require.NoError(ts.T(), err) + + return user +} + +func (ts *IdentityTestSuite) createUserWithIdentity(email string) *User { + user, err := NewUser("", email, "secret", "test", nil) + require.NoError(ts.T(), err) + + err = ts.db.Create(user) + require.NoError(ts.T(), err) + + identityData := map[string]interface{}{ + "sub": uuid.Nil.String(), + "name": "test", + "email": email, + } + require.NoError(ts.T(), err) + + identity, err := NewIdentity(user, "email", identityData) + require.NoError(ts.T(), err) + + err = ts.db.Create(identity) + require.NoError(ts.T(), err) + + return user +} diff --git a/auth_v2.169.0/internal/models/json_map.go b/auth_v2.169.0/internal/models/json_map.go new file mode 100644 index 0000000..77cee64 --- /dev/null +++ 
b/auth_v2.169.0/internal/models/json_map.go @@ -0,0 +1,36 @@ +package models + +import ( + "database/sql/driver" + "encoding/json" + "errors" +) + +type JSONMap map[string]interface{} + +func (j JSONMap) Value() (driver.Value, error) { + data, err := json.Marshal(j) + if err != nil { + return driver.Value(""), err + } + return driver.Value(string(data)), nil +} + +func (j JSONMap) Scan(src interface{}) error { + var source []byte + switch v := src.(type) { + case string: + source = []byte(v) + case []byte: + source = v + case nil: + source = []byte("") + default: + return errors.New("invalid data type for JSONMap") + } + + if len(source) == 0 { + source = []byte("{}") + } + return json.Unmarshal(source, &j) +} diff --git a/auth_v2.169.0/internal/models/linking.go b/auth_v2.169.0/internal/models/linking.go new file mode 100644 index 0000000..ca794bc --- /dev/null +++ b/auth_v2.169.0/internal/models/linking.go @@ -0,0 +1,203 @@ +package models + +import ( + "strings" + + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" +) + +// GetAccountLinkingDomain returns a string that describes the account linking +// domain. An account linking domain describes a set of Identity entities that +// _should_ generally fall under the same User entity. It's just a runtime +// string, and is not typically persisted in the database. This value can vary +// across time. +func GetAccountLinkingDomain(provider string) string { + if strings.HasPrefix(provider, "sso:") { + // when the provider ID is a SSO provider, then the linking + // domain is the provider itself i.e. there can only be one + // user + identity per identity provider + return provider + } + + // otherwise, the linking domain is the default linking domain that + // links all accounts + return "default" +} + +type AccountLinkingDecision = int + +const ( + AccountExists AccountLinkingDecision = iota + CreateAccount + LinkAccount + MultipleAccounts +) + +type AccountLinkingResult struct { + Decision AccountLinkingDecision + User *User + Identities []*Identity + LinkingDomain string + CandidateEmail provider.Email +} + +// DetermineAccountLinking uses the provided data and database state to compute a decision on whether: +// - A new User should be created (CreateAccount) +// - A new Identity should be created (LinkAccount) with a UserID pointing to an existing user account +// - Nothing should be done (AccountExists) +// - It's not possible to decide due to data inconsistency (MultipleAccounts) and the caller should decide +// +// Errors signal failure in processing only, like database access errors. 
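+//
+// A minimal caller sketch follows; tx, config, the provider name and the
+// subject value are placeholders supplied by the caller, not values defined
+// in this package:
+//
+//	emails := []provider.Email{{Email: "person@example.com", Verified: true, Primary: true}}
+//	result, err := DetermineAccountLinking(tx, config, emails, config.JWT.Aud, "provider", "subject-id")
+//	if err != nil {
+//		return err // database failure, not a linking decision
+//	}
+//	switch result.Decision {
+//	case AccountExists:
+//		// sign in result.User
+//	case LinkAccount:
+//		// create a new identity pointing at result.User
+//	case CreateAccount:
+//		// create a brand new user plus identity
+//	case MultipleAccounts:
+//		// ambiguous data; the caller decides how to proceed
+//	}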
+func DetermineAccountLinking(tx *storage.Connection, config *conf.GlobalConfiguration, emails []provider.Email, aud, providerName, sub string) (AccountLinkingResult, error) { + var verifiedEmails []string + var candidateEmail provider.Email + for _, email := range emails { + if email.Verified || config.Mailer.Autoconfirm { + verifiedEmails = append(verifiedEmails, strings.ToLower(email.Email)) + } + if email.Primary { + candidateEmail = email + candidateEmail.Email = strings.ToLower(email.Email) + } + } + + if identity, terr := FindIdentityByIdAndProvider(tx, sub, providerName); terr == nil { + // account exists + + var user *User + if user, terr = FindUserByID(tx, identity.UserID); terr != nil { + return AccountLinkingResult{}, terr + } + + // we overwrite the email with the existing user's email since the user + // could have an empty email + candidateEmail.Email = user.GetEmail() + return AccountLinkingResult{ + Decision: AccountExists, + User: user, + Identities: []*Identity{identity}, + LinkingDomain: GetAccountLinkingDomain(providerName), + CandidateEmail: candidateEmail, + }, nil + } else if !IsNotFoundError(terr) { + return AccountLinkingResult{}, terr + } + + // the identity does not exist, so we need to check if we should create a new account + // or link to an existing one + + // this is the linking domain for the new identity + candidateLinkingDomain := GetAccountLinkingDomain(providerName) + if len(verifiedEmails) == 0 { + // if there are no verified emails, we always decide to create a new account + user, terr := IsDuplicatedEmail(tx, candidateEmail.Email, aud, nil) + if terr != nil { + return AccountLinkingResult{}, terr + } + if user != nil { + candidateEmail.Email = "" + } + return AccountLinkingResult{ + Decision: CreateAccount, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, + }, nil + } + + var similarIdentities []*Identity + var similarUsers []*User + // look for similar identities and users based on email + if terr := tx.Q().Eager().Where("email = any (?)", verifiedEmails).All(&similarIdentities); terr != nil { + return AccountLinkingResult{}, terr + } + + if !strings.HasPrefix(providerName, "sso:") { + // there can be multiple user accounts with the same email when is_sso_user is true + // so we just do not consider those similar user accounts + if terr := tx.Q().Eager().Where("email = any (?) 
and is_sso_user = false", verifiedEmails).All(&similarUsers); terr != nil { + return AccountLinkingResult{}, terr + } + } + + // Need to check if the new identity should be assigned to an + // existing user or to create a new user, according to the automatic + // linking rules + var linkingIdentities []*Identity + + // now let's see if there are any existing and similar identities in + // the same linking domain + for _, identity := range similarIdentities { + if GetAccountLinkingDomain(identity.Provider) == candidateLinkingDomain { + linkingIdentities = append(linkingIdentities, identity) + } + } + + if len(linkingIdentities) == 0 { + if len(similarUsers) == 1 { + // no similarIdentities but a user with the same email exists + // so we link this new identity to the user + // TODO: Backfill the missing identity for the user + return AccountLinkingResult{ + Decision: LinkAccount, + User: similarUsers[0], + Identities: linkingIdentities, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, + }, nil + } else if len(similarUsers) > 1 { + // this shouldn't happen since there is a partial unique index on (email and is_sso_user = false) + return AccountLinkingResult{ + Decision: MultipleAccounts, + Identities: linkingIdentities, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, + }, nil + } else { + // there are no identities in the linking domain, we have to + // create a new identity and new user + return AccountLinkingResult{ + Decision: CreateAccount, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, + }, nil + } + } + + // there is at least one identity in the linking domain let's do a + // sanity check to see if all of the identities in the domain share the + // same user ID + linkingUserId := linkingIdentities[0].UserID + for _, identity := range linkingIdentities { + if identity.UserID != linkingUserId { + // ok this linking domain has more than one user account + // caller should decide what to do + + return AccountLinkingResult{ + Decision: MultipleAccounts, + Identities: linkingIdentities, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, + }, nil + } + } + + // there's only one user ID in this linking domain, we can go on and + // create a new identity and link it to the existing account + + var user *User + var terr error + + if user, terr = FindUserByID(tx, linkingUserId); terr != nil { + return AccountLinkingResult{}, terr + } + + return AccountLinkingResult{ + Decision: LinkAccount, + User: user, + Identities: linkingIdentities, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, + }, nil +} diff --git a/auth_v2.169.0/internal/models/linking_test.go b/auth_v2.169.0/internal/models/linking_test.go new file mode 100644 index 0000000..05d4a8c --- /dev/null +++ b/auth_v2.169.0/internal/models/linking_test.go @@ -0,0 +1,314 @@ +package models + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +type AccountLinkingTestSuite struct { + suite.Suite + + config *conf.GlobalConfiguration + db *storage.Connection +} + +func (ts *AccountLinkingTestSuite) SetupTest() { + TruncateAll(ts.db) +} + +func TestAccountLinking(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + + conn, err := 
test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + ts := &AccountLinkingTestSuite{ + config: globalConfig, + db: conn, + } + defer ts.db.Close() + + suite.Run(t, ts) +} + +func (ts *AccountLinkingTestSuite) TestCreateAccountDecisionNoAccounts() { + // when there are no accounts in the system -- conventional provider + testEmail := provider.Email{ + Email: "test@example.com", + Verified: true, + Primary: true, + } + decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{testEmail}, ts.config.JWT.Aud, "provider", "abcdefgh") + require.NoError(ts.T(), err) + + require.Equal(ts.T(), decision.Decision, CreateAccount) + + // when there are no accounts in the system -- SSO provider + decision, err = DetermineAccountLinking(ts.db, ts.config, []provider.Email{testEmail}, ts.config.JWT.Aud, "sso:f06f9e3d-ff92-4c47-a179-7acf1fda6387", "abcdefgh") + require.NoError(ts.T(), err) + + require.Equal(ts.T(), decision.Decision, CreateAccount) +} + +func (ts *AccountLinkingTestSuite) TestCreateAccountDecisionWithAccounts() { + userA, err := NewUser("", "test@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userA)) + identityA, err := NewIdentity(userA, "provider", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "test@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identityA)) + + userB, err := NewUser("", "test@samltest.id", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userB)) + + ssoProvider := "sso:f06f9e3d-ff92-4c47-a179-7acf1fda6387" + identityB, err := NewIdentity(userB, ssoProvider, map[string]interface{}{ + "sub": userB.ID.String(), + "email": "test@samltest.id", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identityB)) + + // when the email doesn't exist in the system -- conventional provider + decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{ + { + Email: "other@example.com", + Verified: true, + Primary: true, + }, + }, ts.config.JWT.Aud, "provider", "abcdefgh") + require.NoError(ts.T(), err) + + require.Equal(ts.T(), decision.Decision, CreateAccount) + require.Equal(ts.T(), decision.LinkingDomain, "default") + + // when looking for an email that doesn't exist in the SSO linking domain + decision, err = DetermineAccountLinking(ts.db, ts.config, []provider.Email{ + { + Email: "other@samltest.id", + Verified: true, + Primary: true, + }, + }, ts.config.JWT.Aud, ssoProvider, "abcdefgh") + require.NoError(ts.T(), err) + + require.Equal(ts.T(), decision.Decision, CreateAccount) + require.Equal(ts.T(), decision.LinkingDomain, ssoProvider) +} + +func (ts *AccountLinkingTestSuite) TestAccountExists() { + userA, err := NewUser("", "test@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userA)) + identityA, err := NewIdentity(userA, "provider", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "test@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identityA)) + + decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{ + { + Email: "test@example.com", + Verified: true, + Primary: true, + }, + }, ts.config.JWT.Aud, "provider", userA.ID.String()) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), decision.Decision, AccountExists) + require.Equal(ts.T(), decision.User.ID, userA.ID) +} + +func (ts 
*AccountLinkingTestSuite) TestLinkingScenarios() { + userA, err := NewUser("", "test@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userA)) + identityA, err := NewIdentity(userA, "provider", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "test@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identityA)) + + userB, err := NewUser("", "test@samltest.id", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userB)) + + identityB, err := NewIdentity(userB, "sso:f06f9e3d-ff92-4c47-a179-7acf1fda6387", map[string]interface{}{ + "sub": userB.ID.String(), + "email": "test@samltest.id", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identityB)) + + cases := []struct { + desc string + email provider.Email + sub string + provider string + decision AccountLinkingResult + }{ + { + // link decision because the below described identity is in the default linking domain but uses "other-provider" instead of "provder" + desc: "same email address", + email: provider.Email{ + Email: "test@example.com", + Verified: true, + Primary: true, + }, + sub: userA.ID.String(), + provider: "other-provider", + decision: AccountLinkingResult{ + Decision: LinkAccount, + User: userA, + LinkingDomain: "default", + CandidateEmail: provider.Email{ + Email: "test@example.com", + Verified: true, + Primary: true, + }, + }, + }, + { + desc: "same email address in uppercase", + email: provider.Email{ + Email: "TEST@example.com", + Verified: true, + Primary: true, + }, + sub: userA.ID.String(), + provider: "other-provider", + decision: AccountLinkingResult{ + Decision: LinkAccount, + User: userA, + LinkingDomain: "default", + CandidateEmail: provider.Email{ + // expected email should be case insensitive + Email: "test@example.com", + Verified: true, + Primary: true, + }, + }, + }, + { + desc: "no link decision because the SSO linking domain is scoped to the provider unique ID", + email: provider.Email{ + Email: "test@samltest.id", + Verified: true, + Primary: true, + }, + sub: userB.ID.String(), + provider: "sso:f06f9e3d-ff92-4c47-a179-7acf1fda6387", + // decision: AccountExists, + decision: AccountLinkingResult{ + Decision: AccountExists, + User: userB, + LinkingDomain: "sso:f06f9e3d-ff92-4c47-a179-7acf1fda6387", + CandidateEmail: provider.Email{ + Email: "test@samltest.id", + Verified: true, + Primary: true, + }, + }, + }, + { + desc: "create account with empty email because email is unverified and user exists", + email: provider.Email{ + Email: "test@example.com", + Verified: false, + Primary: true, + }, + sub: userA.ID.String(), + provider: "other-provider", + decision: AccountLinkingResult{ + Decision: CreateAccount, + LinkingDomain: "default", + CandidateEmail: provider.Email{ + Email: "", + Verified: false, + Primary: true, + }, + }, + }, + { + desc: "create account because email is unverified and user doesn't exist", + email: provider.Email{ + Email: "other@example.com", + Verified: false, + Primary: true, + }, + sub: "000000000", + provider: "other-provider", + decision: AccountLinkingResult{ + Decision: CreateAccount, + LinkingDomain: "default", + CandidateEmail: provider.Email{ + Email: "other@example.com", + Verified: false, + Primary: true, + }, + }, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{c.email}, ts.config.JWT.Aud, c.provider, 
c.sub) + require.NoError(ts.T(), err) + require.Equal(ts.T(), c.decision.Decision, decision.Decision) + require.Equal(ts.T(), c.decision.LinkingDomain, decision.LinkingDomain) + require.Equal(ts.T(), c.decision.CandidateEmail.Email, decision.CandidateEmail.Email) + require.Equal(ts.T(), c.decision.CandidateEmail.Verified, decision.CandidateEmail.Verified) + require.Equal(ts.T(), c.decision.CandidateEmail.Primary, decision.CandidateEmail.Primary) + }) + } + +} + +func (ts *AccountLinkingTestSuite) TestMultipleAccounts() { + userA, err := NewUser("", "test@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userA)) + identityA, err := NewIdentity(userA, "provider", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "test@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identityA)) + + userB, err := NewUser("", "test-b@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userB)) + identityB, err := NewIdentity(userB, "provider", map[string]interface{}{ + "sub": userB.ID.String(), + "email": "test@example.com", // intentionally same as userA + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identityB)) + + // decision is multiple accounts because there are two distinct + // identities in the same "default" linking domain with the same email + // address pointing to two different user accounts + decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{ + { + Email: "test@example.com", + Verified: true, + Primary: true, + }, + }, ts.config.JWT.Aud, "provider", "abcdefgh") + require.NoError(ts.T(), err) + + require.Equal(ts.T(), decision.Decision, MultipleAccounts) +} diff --git a/auth_v2.169.0/internal/models/one_time_token.go b/auth_v2.169.0/internal/models/one_time_token.go new file mode 100644 index 0000000..3077647 --- /dev/null +++ b/auth_v2.169.0/internal/models/one_time_token.go @@ -0,0 +1,286 @@ +package models + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "strings" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" +) + +type OneTimeTokenType int + +const ( + ConfirmationToken OneTimeTokenType = iota + ReauthenticationToken + RecoveryToken + EmailChangeTokenNew + EmailChangeTokenCurrent + PhoneChangeToken +) + +func (t OneTimeTokenType) String() string { + switch t { + case ConfirmationToken: + return "confirmation_token" + + case ReauthenticationToken: + return "reauthentication_token" + + case RecoveryToken: + return "recovery_token" + + case EmailChangeTokenNew: + return "email_change_token_new" + + case EmailChangeTokenCurrent: + return "email_change_token_current" + + case PhoneChangeToken: + return "phone_change_token" + + default: + panic("OneTimeToken: unreachable case") + } +} + +func ParseOneTimeTokenType(s string) (OneTimeTokenType, error) { + switch s { + case "confirmation_token": + return ConfirmationToken, nil + + case "reauthentication_token": + return ReauthenticationToken, nil + + case "recovery_token": + return RecoveryToken, nil + + case "email_change_token_new": + return EmailChangeTokenNew, nil + + case "email_change_token_current": + return EmailChangeTokenCurrent, nil + + case "phone_change_token": + return PhoneChangeToken, nil + + default: + return 0, fmt.Errorf("OneTimeTokenType: unrecognized string %q", s) + } +} + +func (t OneTimeTokenType) Value() (driver.Value, error) { + 
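+	// driver.Valuer half of the pair: the enum is persisted in its string form
+	// (e.g. ConfirmationToken -> "confirmation_token"); Scan below parses that
+	// string back into the constant, so values round-trip through the
+	// one_time_tokens.token_type column. Sketch:
+	//
+	//	var t OneTimeTokenType
+	//	_ = t.Scan("recovery_token") // t == RecoveryToken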
return t.String(), nil +} + +func (t *OneTimeTokenType) Scan(src interface{}) error { + s, ok := src.(string) + if !ok { + return fmt.Errorf("OneTimeTokenType: scan type is not string but is %T", src) + } + + parsed, err := ParseOneTimeTokenType(s) + if err != nil { + return err + } + + *t = parsed + return nil +} + +type OneTimeTokenNotFoundError struct { +} + +func (e OneTimeTokenNotFoundError) Error() string { + return "One-time token not found" +} + +type OneTimeToken struct { + ID uuid.UUID `json:"id" db:"id"` + + UserID uuid.UUID `json:"user_id" db:"user_id"` + TokenType OneTimeTokenType `json:"token_type" db:"token_type"` + + TokenHash string `json:"token_hash" db:"token_hash"` + RelatesTo string `json:"relates_to" db:"relates_to"` + + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +func (OneTimeToken) TableName() string { + return "one_time_tokens" +} + +func ClearAllOneTimeTokensForUser(tx *storage.Connection, userID uuid.UUID) error { + return tx.Q().Where("user_id = ?", userID).Delete(OneTimeToken{}) +} + +func ClearOneTimeTokenForUser(tx *storage.Connection, userID uuid.UUID, tokenType OneTimeTokenType) error { + if err := tx.Q().Where("token_type = ? and user_id = ?", tokenType, userID).Delete(OneTimeToken{}); err != nil { + return err + } + + return nil +} + +func CreateOneTimeToken(tx *storage.Connection, userID uuid.UUID, relatesTo, tokenHash string, tokenType OneTimeTokenType) error { + if err := ClearOneTimeTokenForUser(tx, userID, tokenType); err != nil { + return err + } + + oneTimeToken := &OneTimeToken{ + ID: uuid.Must(uuid.NewV4()), + UserID: userID, + TokenType: tokenType, + TokenHash: tokenHash, + RelatesTo: strings.ToLower(relatesTo), + } + + if err := tx.Eager().Create(oneTimeToken); err != nil { + return err + } + + return nil +} + +func FindOneTimeToken(tx *storage.Connection, tokenHash string, tokenTypes ...OneTimeTokenType) (*OneTimeToken, error) { + oneTimeToken := &OneTimeToken{} + + query := tx.Eager().Q() + + switch len(tokenTypes) { + case 2: + query = query.Where("(token_type = ? or token_type = ?) and token_hash = ?", tokenTypes[0], tokenTypes[1], tokenHash) + + case 1: + query = query.Where("token_type = ? and token_hash = ?", tokenTypes[0], tokenHash) + + default: + panic("at most 2 token types are accepted") + } + + if err := query.First(oneTimeToken); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, OneTimeTokenNotFoundError{} + } + + return nil, errors.Wrap(err, "error finding one time token") + } + + return oneTimeToken, nil +} + +// FindUserByConfirmationToken finds users with the matching confirmation token. +func FindUserByConfirmationOrRecoveryToken(tx *storage.Connection, token string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, ConfirmationToken, RecoveryToken) + if err != nil { + return nil, err + } + + return FindUserByID(tx, ott.UserID) +} + +// FindUserByConfirmationToken finds users with the matching confirmation token. +func FindUserByConfirmationToken(tx *storage.Connection, token string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, ConfirmationToken) + if err != nil { + return nil, err + } + + return FindUserByID(tx, ott.UserID) +} + +// FindUserByRecoveryToken finds a user with the matching recovery token. 
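+// The token argument is matched against the stored token_hash, so callers are
+// expected to pass the same hashed value that was written earlier with
+// CreateOneTimeToken. Sketch (tokenHash is a placeholder for whatever hashing
+// the caller applies, and the e-mail as relates_to is only an illustration):
+//
+//	if err := CreateOneTimeToken(tx, user.ID, user.GetEmail(), tokenHash, RecoveryToken); err != nil {
+//		return err
+//	}
+//	// ...later, when the user presents the recovery token...
+//	u, err := FindUserByRecoveryToken(tx, tokenHash)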
+func FindUserByRecoveryToken(tx *storage.Connection, token string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, RecoveryToken) + if err != nil { + return nil, err + } + + return FindUserByID(tx, ott.UserID) +} + +// FindUserByEmailChangeToken finds a user with the matching email change token. +func FindUserByEmailChangeToken(tx *storage.Connection, token string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, EmailChangeTokenCurrent, EmailChangeTokenNew) + if err != nil { + return nil, err + } + + return FindUserByID(tx, ott.UserID) +} + +// FindUserByEmailChangeCurrentAndAudience finds a user with the matching email change and audience. +func FindUserByEmailChangeCurrentAndAudience(tx *storage.Connection, email, token, aud string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, EmailChangeTokenCurrent) + if err != nil { + return nil, err + } + + if ott == nil { + ott, err = FindOneTimeToken(tx, "pkce_"+token, EmailChangeTokenCurrent) + if err != nil { + return nil, err + } + } + if ott == nil { + return nil, err + } + + user, err := FindUserByID(tx, ott.UserID) + if err != nil { + return nil, err + } + + if user.Aud != aud && strings.EqualFold(user.GetEmail(), email) { + return nil, UserNotFoundError{} + } + + return user, nil +} + +// FindUserByEmailChangeNewAndAudience finds a user with the matching email change and audience. +func FindUserByEmailChangeNewAndAudience(tx *storage.Connection, email, token, aud string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, EmailChangeTokenNew) + if err != nil && !IsNotFoundError(err) { + return nil, err + } + + if ott == nil { + ott, err = FindOneTimeToken(tx, "pkce_"+token, EmailChangeTokenNew) + if err != nil && !IsNotFoundError(err) { + return nil, err + } + } + if ott == nil { + return nil, err + } + + user, err := FindUserByID(tx, ott.UserID) + if err != nil { + return nil, err + } + + if user.Aud != aud && strings.EqualFold(user.EmailChange, email) { + return nil, UserNotFoundError{} + } + + return user, nil +} + +// FindUserForEmailChange finds a user requesting for an email change +func FindUserForEmailChange(tx *storage.Connection, email, token, aud string, secureEmailChangeEnabled bool) (*User, error) { + if secureEmailChangeEnabled { + if user, err := FindUserByEmailChangeCurrentAndAudience(tx, email, token, aud); err == nil { + return user, err + } else if !IsNotFoundError(err) { + return nil, err + } + } + return FindUserByEmailChangeNewAndAudience(tx, email, token, aud) +} diff --git a/auth_v2.169.0/internal/models/refresh_token.go b/auth_v2.169.0/internal/models/refresh_token.go new file mode 100644 index 0000000..c5fea83 --- /dev/null +++ b/auth_v2.169.0/internal/models/refresh_token.go @@ -0,0 +1,166 @@ +package models + +import ( + "database/sql" + "net/http" + "time" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +// RefreshToken is the database model for refresh tokens. 
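+//
+// Rough lifecycle sketch (tx, user and the incoming *http.Request r are
+// assumed to exist in the caller):
+//
+//	// first sign-in: issues a token and creates a session for it
+//	token, err := GrantAuthenticatedUser(tx, user, GrantParams{})
+//
+//	// refresh: revokes the presented token and issues a child token that
+//	// stays attached to the same session
+//	newToken, err := GrantRefreshTokenSwap(r, tx, user, token)
+//
+//	// reuse detected: revoke every descendant of the leaked token
+//	err = RevokeTokenFamily(tx, token)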
+type RefreshToken struct { + ID int64 `db:"id"` + + Token string `db:"token"` + + UserID uuid.UUID `db:"user_id"` + + Parent storage.NullString `db:"parent"` + SessionId *uuid.UUID `db:"session_id"` + + Revoked bool `db:"revoked"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` + + DONTUSEINSTANCEID uuid.UUID `json:"-" db:"instance_id"` +} + +func (RefreshToken) TableName() string { + tableName := "refresh_tokens" + return tableName +} + +// GrantParams is used to pass session-specific parameters when issuing a new +// refresh token to authenticated users. +type GrantParams struct { + FactorID *uuid.UUID + + SessionNotAfter *time.Time + SessionTag *string + + UserAgent string + IP string +} + +func (g *GrantParams) FillGrantParams(r *http.Request) { + g.UserAgent = r.Header.Get("User-Agent") + g.IP = utilities.GetIPAddress(r) +} + +// GrantAuthenticatedUser creates a refresh token for the provided user. +func GrantAuthenticatedUser(tx *storage.Connection, user *User, params GrantParams) (*RefreshToken, error) { + return createRefreshToken(tx, user, nil, ¶ms) +} + +// GrantRefreshTokenSwap swaps a refresh token for a new one, revoking the provided token. +func GrantRefreshTokenSwap(r *http.Request, tx *storage.Connection, user *User, token *RefreshToken) (*RefreshToken, error) { + var newToken *RefreshToken + err := tx.Transaction(func(rtx *storage.Connection) error { + var terr error + if terr = NewAuditLogEntry(r, tx, user, TokenRevokedAction, "", nil); terr != nil { + return errors.Wrap(terr, "error creating audit log entry") + } + + token.Revoked = true + if terr = tx.UpdateOnly(token, "revoked"); terr != nil { + return terr + } + + newToken, terr = createRefreshToken(rtx, user, token, &GrantParams{}) + return terr + }) + return newToken, err +} + +// RevokeTokenFamily revokes all refresh tokens that descended from the provided token. +func RevokeTokenFamily(tx *storage.Connection, token *RefreshToken) error { + var err error + tablename := (&pop.Model{Value: RefreshToken{}}).TableName() + if token.SessionId != nil { + err = tx.RawQuery(`update `+tablename+` set revoked = true, updated_at = now() where session_id = ? and revoked = false;`, token.SessionId).Exec() + } else { + err = tx.RawQuery(` + with recursive token_family as ( + select id, user_id, token, revoked, parent from `+tablename+` where parent = ? + union + select r.id, r.user_id, r.token, r.revoked, r.parent from `+tablename+` r inner join token_family t on t.token = r.parent + ) + update `+tablename+` r set revoked = true from token_family where token_family.id = r.id;`, token.Token).Exec() + } + if err != nil { + if errors.Cause(err) == sql.ErrNoRows || errors.Is(err, sql.ErrNoRows) { + return nil + } + + return err + } + return nil +} + +func FindTokenBySessionID(tx *storage.Connection, sessionId *uuid.UUID) (*RefreshToken, error) { + refreshToken := &RefreshToken{} + err := tx.Q().Where("instance_id = ? 
and session_id = ?", uuid.Nil, sessionId).Order("created_at asc").First(refreshToken) + if err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, RefreshTokenNotFoundError{} + } + return nil, err + } + return refreshToken, nil +} + +func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshToken, params *GrantParams) (*RefreshToken, error) { + token := &RefreshToken{ + UserID: user.ID, + Token: crypto.SecureToken(), + Parent: "", + } + if oldToken != nil { + token.Parent = storage.NullString(oldToken.Token) + token.SessionId = oldToken.SessionId + } + + if token.SessionId == nil { + session, err := NewSession(user.ID, params.FactorID) + if err != nil { + return nil, errors.Wrap(err, "error instantiating new session object") + } + + if params.SessionNotAfter != nil { + session.NotAfter = params.SessionNotAfter + } + + if params.UserAgent != "" { + session.UserAgent = ¶ms.UserAgent + } + + if params.IP != "" { + session.IP = ¶ms.IP + } + + if params.SessionTag != nil && *params.SessionTag != "" { + session.Tag = params.SessionTag + } + + if err := tx.Create(session); err != nil { + return nil, errors.Wrap(err, "error creating new session") + } + + token.SessionId = &session.ID + } + + if err := tx.Create(token); err != nil { + return nil, errors.Wrap(err, "error creating refresh token") + } + + if err := user.UpdateLastSignInAt(tx); err != nil { + return nil, errors.Wrap(err, "error update user`s last_sign_in field") + } + return token, nil +} diff --git a/auth_v2.169.0/internal/models/refresh_token_test.go b/auth_v2.169.0/internal/models/refresh_token_test.go new file mode 100644 index 0000000..675826d --- /dev/null +++ b/auth_v2.169.0/internal/models/refresh_token_test.go @@ -0,0 +1,89 @@ +package models + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +type RefreshTokenTestSuite struct { + suite.Suite + db *storage.Connection +} + +func (ts *RefreshTokenTestSuite) SetupTest() { + TruncateAll(ts.db) +} + +func TestRefreshToken(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + ts := &RefreshTokenTestSuite{ + db: conn, + } + defer ts.db.Close() + + suite.Run(t, ts) +} + +func (ts *RefreshTokenTestSuite) TestGrantAuthenticatedUser() { + u := ts.createUser() + r, err := GrantAuthenticatedUser(ts.db, u, GrantParams{}) + require.NoError(ts.T(), err) + + require.NotEmpty(ts.T(), r.Token) + require.Equal(ts.T(), u.ID, r.UserID) +} + +func (ts *RefreshTokenTestSuite) TestGrantRefreshTokenSwap() { + u := ts.createUser() + r, err := GrantAuthenticatedUser(ts.db, u, GrantParams{}) + require.NoError(ts.T(), err) + + s, err := GrantRefreshTokenSwap(&http.Request{}, ts.db, u, r) + require.NoError(ts.T(), err) + + _, nr, _, err := FindUserWithRefreshToken(ts.db, r.Token, false) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), r.ID, nr.ID) + require.True(ts.T(), nr.Revoked, "expected old token to be revoked") + + require.NotEqual(ts.T(), r.ID, s.ID) + require.Equal(ts.T(), u.ID, s.UserID) +} + +func (ts *RefreshTokenTestSuite) TestLogout() { + u := ts.createUser() + r, err := GrantAuthenticatedUser(ts.db, u, GrantParams{}) + require.NoError(ts.T(), err) + + require.NoError(ts.T(), Logout(ts.db, u.ID)) + u, r, _, err = 
FindUserWithRefreshToken(ts.db, r.Token, false) + require.Errorf(ts.T(), err, "expected error when there are no refresh tokens to authenticate. user: %v token: %v", u, r) + + require.True(ts.T(), IsNotFoundError(err), "expected NotFoundError") +} + +func (ts *RefreshTokenTestSuite) createUser() *User { + return ts.createUserWithEmail("david@netlify.com") +} + +func (ts *RefreshTokenTestSuite) createUserWithEmail(email string) *User { + user, err := NewUser("", email, "secret", "test", nil) + require.NoError(ts.T(), err) + + err = ts.db.Create(user) + require.NoError(ts.T(), err) + + return user +} diff --git a/auth_v2.169.0/internal/models/sessions.go b/auth_v2.169.0/internal/models/sessions.go new file mode 100644 index 0000000..a93be44 --- /dev/null +++ b/auth_v2.169.0/internal/models/sessions.go @@ -0,0 +1,356 @@ +package models + +import ( + "database/sql" + "fmt" + "sort" + "strings" + "time" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" +) + +type AuthenticatorAssuranceLevel int + +const ( + AAL1 AuthenticatorAssuranceLevel = iota + AAL2 + AAL3 +) + +func (aal AuthenticatorAssuranceLevel) String() string { + switch aal { + case AAL1: + return "aal1" + case AAL2: + return "aal2" + case AAL3: + return "aal3" + default: + return "" + } +} + +// AMREntry represents a method that a user has logged in together with the corresponding time +type AMREntry struct { + Method string `json:"method"` + Timestamp int64 `json:"timestamp"` + Provider string `json:"provider,omitempty"` +} + +type sortAMREntries struct { + Array []AMREntry +} + +func (s sortAMREntries) Len() int { + return len(s.Array) +} + +func (s sortAMREntries) Less(i, j int) bool { + return s.Array[i].Timestamp < s.Array[j].Timestamp +} + +func (s sortAMREntries) Swap(i, j int) { + s.Array[j], s.Array[i] = s.Array[i], s.Array[j] +} + +type Session struct { + ID uuid.UUID `json:"-" db:"id"` + UserID uuid.UUID `json:"user_id" db:"user_id"` + + // NotAfter is overriden by timeboxed sessions. + NotAfter *time.Time `json:"not_after,omitempty" db:"not_after"` + + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + FactorID *uuid.UUID `json:"factor_id" db:"factor_id"` + AMRClaims []AMRClaim `json:"amr,omitempty" has_many:"amr_claims"` + AAL *string `json:"aal" db:"aal"` + + RefreshedAt *time.Time `json:"refreshed_at,omitempty" db:"refreshed_at"` + UserAgent *string `json:"user_agent,omitempty" db:"user_agent"` + IP *string `json:"ip,omitempty" db:"ip"` + + Tag *string `json:"tag" db:"tag"` +} + +func (Session) TableName() string { + tableName := "sessions" + return tableName +} + +func (s *Session) LastRefreshedAt(refreshTokenTime *time.Time) time.Time { + refreshedAt := s.RefreshedAt + + if refreshedAt == nil || refreshedAt.IsZero() { + if refreshTokenTime != nil { + rtt := *refreshTokenTime + + if rtt.IsZero() { + return s.CreatedAt + } else if rtt.After(s.CreatedAt) { + return rtt + } + } + + return s.CreatedAt + } + + return *refreshedAt +} + +func (s *Session) UpdateOnlyRefreshInfo(tx *storage.Connection) error { + // TODO(kangmingtay): The underlying database type uses timestamp without timezone, + // so we need to convert the value to UTC before updating it. + // In the future, we should add a migration to update the type to contain the timezone. 
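+	// Callers are expected to have populated RefreshedAt (and usually UserAgent
+	// and IP) before calling, since the next line dereferences the pointer.
+	// Hypothetical caller:
+	//
+	//	now := time.Now()
+	//	session.RefreshedAt = &now
+	//	if err := session.UpdateOnlyRefreshInfo(tx); err != nil {
+	//		// handle error
+	//	}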
+ *s.RefreshedAt = s.RefreshedAt.UTC() + return tx.UpdateOnly(s, "refreshed_at", "user_agent", "ip") +} + +type SessionValidityReason = int + +const ( + SessionValid SessionValidityReason = iota + SessionPastNotAfter = iota + SessionPastTimebox = iota + SessionTimedOut = iota +) + +func (s *Session) CheckValidity(now time.Time, refreshTokenTime *time.Time, timebox, inactivityTimeout *time.Duration) SessionValidityReason { + if s.NotAfter != nil && now.After(*s.NotAfter) { + return SessionPastNotAfter + } + + if timebox != nil && *timebox != 0 && now.After(s.CreatedAt.Add(*timebox)) { + return SessionPastTimebox + } + + if inactivityTimeout != nil && *inactivityTimeout != 0 && now.After(s.LastRefreshedAt(refreshTokenTime).Add(*inactivityTimeout)) { + return SessionTimedOut + } + + return SessionValid +} + +func (s *Session) DetermineTag(tags []string) string { + if len(tags) == 0 { + return "" + } + + if s.Tag == nil { + return tags[0] + } + + tag := *s.Tag + if tag == "" { + return tags[0] + } + + for _, t := range tags { + if t == tag { + return tag + } + } + + return tags[0] +} + +func NewSession(userID uuid.UUID, factorID *uuid.UUID) (*Session, error) { + id := uuid.Must(uuid.NewV4()) + + defaultAAL := AAL1.String() + + session := &Session{ + ID: id, + AAL: &defaultAAL, + UserID: userID, + FactorID: factorID, + } + + return session, nil +} + +// FindSessionByID looks up a Session by the provided id. If forUpdate is set +// to true, then the SELECT statement used by the query has the form SELECT ... +// FOR UPDATE SKIP LOCKED. This means that a FOR UPDATE lock will only be +// acquired if there's no other lock. In case there is a lock, a +// IsNotFound(err) error will be retured. +func FindSessionByID(tx *storage.Connection, id uuid.UUID, forUpdate bool) (*Session, error) { + session := &Session{} + + if forUpdate { + // pop does not provide us with a way to execute FOR UPDATE + // queries which lock the rows affected by the query from + // being accessed by any other transaction that also uses FOR + // UPDATE + if err := tx.RawQuery(fmt.Sprintf("SELECT * FROM %q WHERE id = ? LIMIT 1 FOR UPDATE SKIP LOCKED;", session.TableName()), id).First(session); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SessionNotFoundError{} + } + + return nil, err + } + } + + // once the rows are locked (if forUpdate was true), we can query again using pop + if err := tx.Eager().Q().Where("id = ?", id).First(session); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SessionNotFoundError{} + } + return nil, errors.Wrap(err, "error finding session") + } + return session, nil +} + +func FindSessionByUserID(tx *storage.Connection, userId uuid.UUID) (*Session, error) { + session := &Session{} + if err := tx.Eager().Q().Where("user_id = ?", userId).Order("created_at asc").First(session); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SessionNotFoundError{} + } + return nil, errors.Wrap(err, "error finding session") + } + return session, nil +} + +func FindSessionsByFactorID(tx *storage.Connection, factorID uuid.UUID) ([]*Session, error) { + sessions := []*Session{} + if err := tx.Q().Where("factor_id = ?", factorID).All(&sessions); err != nil { + return nil, err + } + return sessions, nil +} + +// FindAllSessionsForUser finds all of the sessions for a user. If forUpdate is +// set, it will first lock on the user row which can be used to prevent issues +// with concurrency. 
+// If the lock is not acquired, it will return a UserNotFoundError and the
+// operation should be retried. If there are no sessions for the user, a nil
+// result is returned without an error.
+func FindAllSessionsForUser(tx *storage.Connection, userId uuid.UUID, forUpdate bool) ([]*Session, error) {
+	if forUpdate {
+		user := &User{}
+		if err := tx.RawQuery(fmt.Sprintf("SELECT id FROM %q WHERE id = ? LIMIT 1 FOR UPDATE SKIP LOCKED;", user.TableName()), userId).First(user); err != nil {
+			if errors.Cause(err) == sql.ErrNoRows {
+				return nil, UserNotFoundError{}
+			}
+
+			return nil, err
+		}
+	}
+
+	var sessions []*Session
+	if err := tx.Where("user_id = ?", userId).All(&sessions); err != nil {
+		if errors.Cause(err) == sql.ErrNoRows {
+			return nil, nil
+		}
+
+		return nil, err
+	}
+
+	return sessions, nil
+}
+
+func updateFactorAssociatedSessions(tx *storage.Connection, userID, factorID uuid.UUID, aal string) error {
+	return tx.RawQuery("UPDATE "+(&pop.Model{Value: Session{}}).TableName()+" set aal = ?, factor_id = ? WHERE user_id = ? AND factor_id = ?", aal, nil, userID, factorID).Exec()
+}
+
+func InvalidateSessionsWithAALLessThan(tx *storage.Connection, userID uuid.UUID, level string) error {
+	return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE user_id = ? AND aal < ?", userID, level).Exec()
+}
+
+// Logout deletes all sessions for a user.
+func Logout(tx *storage.Connection, userId uuid.UUID) error {
+	return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE user_id = ?", userId).Exec()
+}
+
+// LogoutSession deletes the current session for a user
+func LogoutSession(tx *storage.Connection, sessionId uuid.UUID) error {
+	return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE id = ?", sessionId).Exec()
+}
+
+// LogoutAllExceptMe deletes all sessions for a user except the current one
+func LogoutAllExceptMe(tx *storage.Connection, sessionId uuid.UUID, userID uuid.UUID) error {
+	return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE id != ? 
AND user_id = ?", sessionId, userID).Exec() +} + +func (s *Session) UpdateAALAndAssociatedFactor(tx *storage.Connection, aal AuthenticatorAssuranceLevel, factorID *uuid.UUID) error { + s.FactorID = factorID + aalAsString := aal.String() + s.AAL = &aalAsString + return tx.UpdateOnly(s, "aal", "factor_id") +} + +func (s *Session) CalculateAALAndAMR(user *User) (aal AuthenticatorAssuranceLevel, amr []AMREntry, err error) { + amr, aal = []AMREntry{}, AAL1 + for _, claim := range s.AMRClaims { + if claim.IsAAL2Claim() { + aal = AAL2 + } + amr = append(amr, AMREntry{Method: claim.GetAuthenticationMethod(), Timestamp: claim.UpdatedAt.Unix()}) + } + + // makes sure that the AMR claims are always ordered most-recent first + + // sort in ascending order + sort.Sort(sortAMREntries{ + Array: amr, + }) + + // now reverse for descending order + _ = sort.Reverse(sortAMREntries{ + Array: amr, + }) + + lastIndex := len(amr) - 1 + + if lastIndex > -1 && amr[lastIndex].Method == SSOSAML.String() { + // initial AMR claim is from sso/saml, we need to add information + // about the provider that was used for the authentication + identities := user.Identities + + if len(identities) == 1 { + identity := identities[0] + + if identity.IsForSSOProvider() { + amr[lastIndex].Provider = strings.TrimPrefix(identity.Provider, "sso:") + } + } + + // otherwise we can't identify that this user account has only + // one SSO identity, so we are not encoding the provider at + // this time + } + + return aal, amr, nil +} + +func (s *Session) GetAAL() string { + if s.AAL == nil { + return "" + } + return *(s.AAL) +} + +func (s *Session) IsAAL2() bool { + return s.GetAAL() == AAL2.String() +} + +// FindCurrentlyActiveRefreshToken returns the currently active refresh +// token in the session. This is the last created (ordered by the serial +// primary key) non-revoked refresh token for the session. +func (s *Session) FindCurrentlyActiveRefreshToken(tx *storage.Connection) (*RefreshToken, error) { + var activeRefreshToken RefreshToken + + if err := tx.Q().Where("session_id = ? 
and revoked is false", s.ID).Order("id desc").First(&activeRefreshToken); err != nil { + if errors.Cause(err) == sql.ErrNoRows || errors.Is(err, sql.ErrNoRows) { + return nil, RefreshTokenNotFoundError{} + } + + return nil, err + } + + return &activeRefreshToken, nil +} diff --git a/auth_v2.169.0/internal/models/sessions_test.go b/auth_v2.169.0/internal/models/sessions_test.go new file mode 100644 index 0000000..9dce78e --- /dev/null +++ b/auth_v2.169.0/internal/models/sessions_test.go @@ -0,0 +1,104 @@ +package models + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +type SessionsTestSuite struct { + suite.Suite + db *storage.Connection + Config *conf.GlobalConfiguration +} + +func (ts *SessionsTestSuite) SetupTest() { + TruncateAll(ts.db) + email := "test@example.com" + user, err := NewUser("", email, "secret", ts.Config.JWT.Aud, nil) + require.NoError(ts.T(), err) + + err = ts.db.Create(user) + require.NoError(ts.T(), err) +} + +func TestSession(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + ts := &SessionsTestSuite{ + db: conn, + Config: globalConfig, + } + defer ts.db.Close() + suite.Run(t, ts) +} + +func (ts *SessionsTestSuite) TestFindBySessionIDWithForUpdate() { + u, err := FindUserByEmailAndAudience(ts.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + session, err := NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(session)) + + found, err := FindSessionByID(ts.db, session.ID, true) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), session.ID, found.ID) +} + +func (ts *SessionsTestSuite) AddClaimAndReloadSession(session *Session, claim AuthenticationMethod) *Session { + err := AddClaimToSession(ts.db, session.ID, claim) + require.NoError(ts.T(), err) + session, err = FindSessionByID(ts.db, session.ID, false) + require.NoError(ts.T(), err) + return session +} + +func (ts *SessionsTestSuite) TestCalculateAALAndAMR() { + totalDistinctClaims := 3 + u, err := FindUserByEmailAndAudience(ts.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + session, err := NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(session)) + + session = ts.AddClaimAndReloadSession(session, PasswordGrant) + + firstClaimAddedTime := time.Now() + session = ts.AddClaimAndReloadSession(session, TOTPSignIn) + + _, _, err = session.CalculateAALAndAMR(u) + require.NoError(ts.T(), err) + + session = ts.AddClaimAndReloadSession(session, TOTPSignIn) + + session = ts.AddClaimAndReloadSession(session, SSOSAML) + + aal, amr, err := session.CalculateAALAndAMR(u) + require.NoError(ts.T(), err) + + require.Equal(ts.T(), AAL2, aal) + require.Equal(ts.T(), totalDistinctClaims, len(amr)) + + found := false + for _, claim := range session.AMRClaims { + if claim.GetAuthenticationMethod() == TOTPSignIn.String() { + require.True(ts.T(), firstClaimAddedTime.Before(claim.UpdatedAt)) + found = true + } + } + + for _, claim := range amr { + if claim.Method == SSOSAML.String() { + require.NotNil(ts.T(), claim.Provider) + } + } + require.True(ts.T(), found) +} diff --git a/auth_v2.169.0/internal/models/sso.go b/auth_v2.169.0/internal/models/sso.go new file mode 
100644 index 0000000..28c2429 --- /dev/null +++ b/auth_v2.169.0/internal/models/sso.go @@ -0,0 +1,262 @@ +package models + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "reflect" + "strings" + "time" + + "github.com/crewjam/saml" + "github.com/crewjam/saml/samlsp" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" +) + +type SSOProvider struct { + ID uuid.UUID `db:"id" json:"id"` + + SAMLProvider SAMLProvider `has_one:"saml_providers" fk_id:"sso_provider_id" json:"saml,omitempty"` + SSODomains []SSODomain `has_many:"sso_domains" fk_id:"sso_provider_id" json:"domains"` + + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + +func (p SSOProvider) TableName() string { + return "sso_providers" +} + +func (p SSOProvider) Type() string { + return "saml" +} + +type SAMLAttribute struct { + Name string `json:"name,omitempty"` + Names []string `json:"names,omitempty"` + Default interface{} `json:"default,omitempty"` + Array bool `json:"array,omitempty"` +} + +type SAMLAttributeMapping struct { + Keys map[string]SAMLAttribute `json:"keys,omitempty"` +} + +func (m *SAMLAttributeMapping) Equal(o *SAMLAttributeMapping) bool { + if m == o { + return true + } + + if m == nil || o == nil { + return false + } + + if m.Keys == nil && o.Keys == nil { + return true + } + + if len(m.Keys) != len(o.Keys) { + return false + } + + for mkey, mvalue := range m.Keys { + value, ok := o.Keys[mkey] + if !ok { + return false + } + + if mvalue.Name != value.Name || len(mvalue.Names) != len(value.Names) { + return false + } + + for i := 0; i < len(mvalue.Names); i += 1 { + if mvalue.Names[i] != value.Names[i] { + return false + } + } + + if !reflect.DeepEqual(mvalue.Default, value.Default) { + return false + } + + if mvalue.Array != value.Array { + return false + } + } + + return true +} + +func (m *SAMLAttributeMapping) Scan(src interface{}) error { + b, ok := src.([]byte) + if !ok { + return errors.New("scan source was not []byte") + } + err := json.Unmarshal(b, m) + if err != nil { + return err + } + return nil +} + +func (m SAMLAttributeMapping) Value() (driver.Value, error) { + b, err := json.Marshal(m) + if err != nil { + return nil, err + } + return string(b), nil +} + +type SAMLProvider struct { + ID uuid.UUID `db:"id" json:"-"` + + SSOProvider *SSOProvider `belongs_to:"sso_providers" json:"-"` + SSOProviderID uuid.UUID `db:"sso_provider_id" json:"-"` + + EntityID string `db:"entity_id" json:"entity_id"` + MetadataXML string `db:"metadata_xml" json:"metadata_xml,omitempty"` + MetadataURL *string `db:"metadata_url" json:"metadata_url,omitempty"` + + AttributeMapping SAMLAttributeMapping `db:"attribute_mapping" json:"attribute_mapping,omitempty"` + + NameIDFormat *string `db:"name_id_format" json:"name_id_format,omitempty"` + + CreatedAt time.Time `db:"created_at" json:"-"` + UpdatedAt time.Time `db:"updated_at" json:"-"` +} + +func (p SAMLProvider) TableName() string { + return "saml_providers" +} + +func (p SAMLProvider) EntityDescriptor() (*saml.EntityDescriptor, error) { + return samlsp.ParseMetadata([]byte(p.MetadataXML)) +} + +type SSODomain struct { + ID uuid.UUID `db:"id" json:"-"` + + SSOProvider *SSOProvider `belongs_to:"sso_providers" json:"-"` + SSOProviderID uuid.UUID `db:"sso_provider_id" json:"-"` + + Domain string `db:"domain" json:"domain"` + + CreatedAt time.Time `db:"created_at" json:"-"` + UpdatedAt time.Time `db:"updated_at" json:"-"` +} + +func (d SSODomain) 
TableName() string { + return "sso_domains" +} + +type SAMLRelayState struct { + ID uuid.UUID `db:"id"` + + SSOProviderID uuid.UUID `db:"sso_provider_id"` + + RequestID string `db:"request_id"` + ForEmail *string `db:"for_email"` + + RedirectTo string `db:"redirect_to"` + + CreatedAt time.Time `db:"created_at" json:"-"` + UpdatedAt time.Time `db:"updated_at" json:"-"` + FlowStateID *uuid.UUID `db:"flow_state_id" json:"flow_state_id,omitempty"` + FlowState *FlowState `db:"-" json:"flow_state,omitempty" belongs_to:"flow_state"` +} + +func (s SAMLRelayState) TableName() string { + return "saml_relay_states" +} + +func FindSAMLProviderByEntityID(tx *storage.Connection, entityId string) (*SSOProvider, error) { + var samlProvider SAMLProvider + if err := tx.Q().Where("entity_id = ?", entityId).First(&samlProvider); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SSOProviderNotFoundError{} + } + + return nil, errors.Wrap(err, "error finding SAML SSO provider by EntityID") + } + + var ssoProvider SSOProvider + if err := tx.Eager().Q().Where("id = ?", samlProvider.SSOProviderID).First(&ssoProvider); err != nil { + return nil, errors.Wrap(err, "error finding SAML SSO provider by ID (via EntityID)") + } + + return &ssoProvider, nil +} + +func FindSSOProviderByID(tx *storage.Connection, id uuid.UUID) (*SSOProvider, error) { + var ssoProvider SSOProvider + + if err := tx.Eager().Q().Where("id = ?", id).First(&ssoProvider); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SSOProviderNotFoundError{} + } + + return nil, errors.Wrap(err, "error finding SAML SSO provider by ID") + } + + return &ssoProvider, nil +} + +func FindSSOProviderForEmailAddress(tx *storage.Connection, emailAddress string) (*SSOProvider, error) { + parts := strings.Split(emailAddress, "@") + emailDomain := strings.ToLower(parts[1]) + + return FindSSOProviderByDomain(tx, emailDomain) +} + +func FindSSOProviderByDomain(tx *storage.Connection, domain string) (*SSOProvider, error) { + var ssoDomain SSODomain + + if err := tx.Q().Where("domain = ?", domain).First(&ssoDomain); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SSOProviderNotFoundError{} + } + + return nil, errors.Wrap(err, "error finding SAML SSO domain") + } + + var ssoProvider SSOProvider + if err := tx.Eager().Q().Where("id = ?", ssoDomain.SSOProviderID).First(&ssoProvider); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SSOProviderNotFoundError{} + } + + return nil, errors.Wrap(err, "error finding SAML SSO provider by ID (via domain)") + } + + return &ssoProvider, nil +} + +func FindAllSAMLProviders(tx *storage.Connection) ([]SSOProvider, error) { + var providers []SSOProvider + + if err := tx.Eager().All(&providers); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil + } + + return nil, errors.Wrap(err, "error loading all SAML SSO providers") + } + + return providers, nil +} + +func FindSAMLRelayStateByID(tx *storage.Connection, id uuid.UUID) (*SAMLRelayState, error) { + var state SAMLRelayState + + if err := tx.Eager().Q().Where("id = ?", id).First(&state); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, SAMLRelayStateNotFoundError{} + } + + return nil, errors.Wrap(err, "error loading SAML Relay State") + } + + return &state, nil +} diff --git a/auth_v2.169.0/internal/models/sso_test.go b/auth_v2.169.0/internal/models/sso_test.go new file mode 100644 index 0000000..b6c9656 --- /dev/null +++ b/auth_v2.169.0/internal/models/sso_test.go @@ -0,0 
+1,232 @@ +package models + +import ( + tst "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" +) + +type SSOTestSuite struct { + suite.Suite + + db *storage.Connection +} + +func (ts *SSOTestSuite) SetupTest() { + TruncateAll(ts.db) +} + +func TestSSO(t *tst.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + ts := &SSOTestSuite{ + db: conn, + } + defer ts.db.Close() + + suite.Run(t, ts) +} + +func (ts *SSOTestSuite) TestConstraints() { + type exampleSpec struct { + Provider *SSOProvider + } + + examples := []exampleSpec{ + { + Provider: &SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "", + MetadataXML: "", + }, + }, + }, + { + Provider: &SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata", + MetadataXML: "", + }, + }, + }, + { + Provider: &SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata", + MetadataXML: "", + }, + SSODomains: []SSODomain{ + { + Domain: "", + }, + }, + }, + }, + } + + for i, example := range examples { + require.Error(ts.T(), ts.db.Eager().Create(example.Provider), "Example %d should have failed with error", i) + } +} + +func (ts *SSOTestSuite) TestDomainUniqueness() { + require.NoError(ts.T(), ts.db.Eager().Create(&SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata1", + MetadataXML: "", + }, + SSODomains: []SSODomain{ + { + Domain: "example.com", + }, + }, + })) + + require.Error(ts.T(), ts.db.Eager().Create(&SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata2", + MetadataXML: "", + }, + SSODomains: []SSODomain{ + { + Domain: "example.com", + }, + }, + })) +} + +func (ts *SSOTestSuite) TestEntityIDUniqueness() { + require.NoError(ts.T(), ts.db.Eager().Create(&SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata", + MetadataXML: "", + }, + SSODomains: []SSODomain{ + { + Domain: "example.com", + }, + }, + })) + + require.Error(ts.T(), ts.db.Eager().Create(&SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata", + MetadataXML: "", + }, + SSODomains: []SSODomain{ + { + Domain: "example.net", + }, + }, + })) +} + +func (ts *SSOTestSuite) TestFindSSOProviderForEmailAddress() { + provider := &SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata", + MetadataXML: "", + }, + SSODomains: []SSODomain{ + { + Domain: "example.com", + }, + { + Domain: "example.org", + }, + }, + } + + require.NoError(ts.T(), ts.db.Eager().Create(provider), "provider creation failed") + + type exampleSpec struct { + Address string + Provider *SSOProvider + } + + examples := []exampleSpec{ + { + Address: "someone@example.com", + Provider: provider, + }, + { + Address: "someone@example.org", + Provider: provider, + }, + { + Address: "someone@example.net", + Provider: nil, + }, + } + + for i, example := range examples { + rp, err := FindSSOProviderForEmailAddress(ts.db, example.Address) + + if nil == example.Provider { + require.Nil(ts.T(), rp) + require.True(ts.T(), IsNotFoundError(err), "Example %d failed with error %w", i, err) + } else { + require.Nil(ts.T(), err, "Example %d failed with error %w", i, err) + require.Equal(ts.T(), rp.ID, 
example.Provider.ID) + } + } +} + +func (ts *SSOTestSuite) TestFindSAMLProviderByEntityID() { + provider := &SSOProvider{ + SAMLProvider: SAMLProvider{ + EntityID: "https://example.com/saml/metadata", + MetadataXML: "", + }, + SSODomains: []SSODomain{ + { + Domain: "example.com", + }, + { + Domain: "example.org", + }, + }, + } + + require.NoError(ts.T(), ts.db.Eager().Create(provider)) + + type exampleSpec struct { + EntityID string + Provider *SSOProvider + } + + examples := []exampleSpec{ + { + EntityID: "https://example.com/saml/metadata", + Provider: provider, + }, + { + EntityID: "https://example.com/saml/metadata/", + Provider: nil, + }, + { + EntityID: "", + Provider: nil, + }, + } + + for i, example := range examples { + rp, err := FindSAMLProviderByEntityID(ts.db, example.EntityID) + + if nil == example.Provider { + require.True(ts.T(), IsNotFoundError(err), "Example %d failed with error", i) + require.Nil(ts.T(), rp) + } else { + require.Nil(ts.T(), err, "Example %d failed with error %w", i, err) + require.Equal(ts.T(), rp.ID, example.Provider.ID) + } + } +} diff --git a/auth_v2.169.0/internal/models/user.go b/auth_v2.169.0/internal/models/user.go new file mode 100644 index 0000000..3b16a54 --- /dev/null +++ b/auth_v2.169.0/internal/models/user.go @@ -0,0 +1,989 @@ +package models + +import ( + "context" + "crypto/sha256" + "database/sql" + "encoding/base64" + "fmt" + "strings" + "time" + + "github.com/go-webauthn/webauthn/webauthn" + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/storage" + "golang.org/x/crypto/bcrypt" +) + +// User respresents a registered user with email/password authentication +type User struct { + ID uuid.UUID `json:"id" db:"id"` + + Aud string `json:"aud" db:"aud"` + Role string `json:"role" db:"role"` + Email storage.NullString `json:"email" db:"email"` + IsSSOUser bool `json:"-" db:"is_sso_user"` + + EncryptedPassword *string `json:"-" db:"encrypted_password"` + EmailConfirmedAt *time.Time `json:"email_confirmed_at,omitempty" db:"email_confirmed_at"` + InvitedAt *time.Time `json:"invited_at,omitempty" db:"invited_at"` + + Phone storage.NullString `json:"phone" db:"phone"` + PhoneConfirmedAt *time.Time `json:"phone_confirmed_at,omitempty" db:"phone_confirmed_at"` + + ConfirmationToken string `json:"-" db:"confirmation_token"` + ConfirmationSentAt *time.Time `json:"confirmation_sent_at,omitempty" db:"confirmation_sent_at"` + + // For backward compatibility only. Use EmailConfirmedAt or PhoneConfirmedAt instead. 
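+	// Read-only from the application's point of view (note the rw:"r" tag on
+	// the field below); the value is expected to be maintained by the database
+	// from EmailConfirmedAt/PhoneConfirmedAt.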
+ ConfirmedAt *time.Time `json:"confirmed_at,omitempty" db:"confirmed_at" rw:"r"` + + RecoveryToken string `json:"-" db:"recovery_token"` + RecoverySentAt *time.Time `json:"recovery_sent_at,omitempty" db:"recovery_sent_at"` + + EmailChangeTokenCurrent string `json:"-" db:"email_change_token_current"` + EmailChangeTokenNew string `json:"-" db:"email_change_token_new"` + EmailChange string `json:"new_email,omitempty" db:"email_change"` + EmailChangeSentAt *time.Time `json:"email_change_sent_at,omitempty" db:"email_change_sent_at"` + EmailChangeConfirmStatus int `json:"-" db:"email_change_confirm_status"` + + PhoneChangeToken string `json:"-" db:"phone_change_token"` + PhoneChange string `json:"new_phone,omitempty" db:"phone_change"` + PhoneChangeSentAt *time.Time `json:"phone_change_sent_at,omitempty" db:"phone_change_sent_at"` + + ReauthenticationToken string `json:"-" db:"reauthentication_token"` + ReauthenticationSentAt *time.Time `json:"reauthentication_sent_at,omitempty" db:"reauthentication_sent_at"` + + LastSignInAt *time.Time `json:"last_sign_in_at,omitempty" db:"last_sign_in_at"` + + AppMetaData JSONMap `json:"app_metadata" db:"raw_app_meta_data"` + UserMetaData JSONMap `json:"user_metadata" db:"raw_user_meta_data"` + + Factors []Factor `json:"factors,omitempty" has_many:"factors"` + Identities []Identity `json:"identities" has_many:"identities"` + + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + BannedUntil *time.Time `json:"banned_until,omitempty" db:"banned_until"` + DeletedAt *time.Time `json:"deleted_at,omitempty" db:"deleted_at"` + IsAnonymous bool `json:"is_anonymous" db:"is_anonymous"` + + DONTUSEINSTANCEID uuid.UUID `json:"-" db:"instance_id"` +} + +func NewUserWithPasswordHash(phone, email, passwordHash, aud string, userData map[string]interface{}) (*User, error) { + if strings.HasPrefix(passwordHash, crypto.Argon2Prefix) { + _, err := crypto.ParseArgon2Hash(passwordHash) + if err != nil { + return nil, err + } + } else if strings.HasPrefix(passwordHash, crypto.FirebaseScryptPrefix) { + _, err := crypto.ParseFirebaseScryptHash(passwordHash) + if err != nil { + return nil, err + } + } else { + // verify that the hash is a bcrypt hash + _, err := bcrypt.Cost([]byte(passwordHash)) + if err != nil { + return nil, err + } + } + id := uuid.Must(uuid.NewV4()) + user := &User{ + ID: id, + Aud: aud, + Email: storage.NullString(strings.ToLower(email)), + Phone: storage.NullString(phone), + UserMetaData: userData, + EncryptedPassword: &passwordHash, + } + return user, nil +} + +// NewUser initializes a new user from an email, password and user data. 
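+//
+// A minimal usage sketch (tx is assumed to be an open *storage.Connection;
+// the audience value is illustrative):
+//
+//	user, err := NewUser("", "someone@example.com", "secret", "authenticated", nil)
+//	if err != nil {
+//		return err
+//	}
+//	if err := tx.Create(user); err != nil {
+//		return err
+//	}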
+func NewUser(phone, email, password, aud string, userData map[string]interface{}) (*User, error) { + passwordHash := "" + + if password != "" { + pw, err := crypto.GenerateFromPassword(context.Background(), password) + if err != nil { + return nil, err + } + + passwordHash = pw + } + + if userData == nil { + userData = make(map[string]interface{}) + } + + id := uuid.Must(uuid.NewV4()) + user := &User{ + ID: id, + Aud: aud, + Email: storage.NullString(strings.ToLower(email)), + Phone: storage.NullString(phone), + UserMetaData: userData, + EncryptedPassword: &passwordHash, + } + return user, nil +} + +// TableName overrides the table name used by pop +func (User) TableName() string { + tableName := "users" + return tableName +} + +func (u *User) HasPassword() bool { + var pwd string + + if u.EncryptedPassword != nil { + pwd = *u.EncryptedPassword + } + + return pwd != "" +} + +// BeforeSave is invoked before the user is saved to the database +func (u *User) BeforeSave(tx *pop.Connection) error { + if u.EmailConfirmedAt != nil && u.EmailConfirmedAt.IsZero() { + u.EmailConfirmedAt = nil + } + if u.PhoneConfirmedAt != nil && u.PhoneConfirmedAt.IsZero() { + u.PhoneConfirmedAt = nil + } + if u.InvitedAt != nil && u.InvitedAt.IsZero() { + u.InvitedAt = nil + } + if u.ConfirmationSentAt != nil && u.ConfirmationSentAt.IsZero() { + u.ConfirmationSentAt = nil + } + if u.RecoverySentAt != nil && u.RecoverySentAt.IsZero() { + u.RecoverySentAt = nil + } + if u.EmailChangeSentAt != nil && u.EmailChangeSentAt.IsZero() { + u.EmailChangeSentAt = nil + } + if u.PhoneChangeSentAt != nil && u.PhoneChangeSentAt.IsZero() { + u.PhoneChangeSentAt = nil + } + if u.ReauthenticationSentAt != nil && u.ReauthenticationSentAt.IsZero() { + u.ReauthenticationSentAt = nil + } + if u.LastSignInAt != nil && u.LastSignInAt.IsZero() { + u.LastSignInAt = nil + } + if u.BannedUntil != nil && u.BannedUntil.IsZero() { + u.BannedUntil = nil + } + return nil +} + +// IsConfirmed checks if a user has already been +// registered and confirmed. +func (u *User) IsConfirmed() bool { + return u.EmailConfirmedAt != nil +} + +// HasBeenInvited checks if user has been invited +func (u *User) HasBeenInvited() bool { + return u.InvitedAt != nil +} + +// IsPhoneConfirmed checks if a user's phone has already been +// registered and confirmed. +func (u *User) IsPhoneConfirmed() bool { + return u.PhoneConfirmedAt != nil +} + +// SetRole sets the users Role to roleName +func (u *User) SetRole(tx *storage.Connection, roleName string) error { + u.Role = strings.TrimSpace(roleName) + return tx.UpdateOnly(u, "role") +} + +// HasRole returns true when the users role is set to roleName +func (u *User) HasRole(roleName string) bool { + return u.Role == roleName +} + +// GetEmail returns the user's email as a string +func (u *User) GetEmail() string { + return string(u.Email) +} + +// GetPhone returns the user's phone number as a string +func (u *User) GetPhone() string { + return string(u.Phone) +} + +// UpdateUserMetaData sets all user data from a map of updates, +// ensuring that it doesn't override attributes that are not +// in the provided map. 
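+//
+// For example (sketch; tx is an open *storage.Connection), the call below
+// sets "plan", removes "beta" and leaves every other key untouched:
+//
+//	err := u.UpdateUserMetaData(tx, map[string]interface{}{
+//		"plan": "pro",
+//		"beta": nil,
+//	})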
+func (u *User) UpdateUserMetaData(tx *storage.Connection, updates map[string]interface{}) error { + if u.UserMetaData == nil { + u.UserMetaData = updates + } else { + for key, value := range updates { + if value != nil { + u.UserMetaData[key] = value + } else { + delete(u.UserMetaData, key) + } + } + } + return tx.UpdateOnly(u, "raw_user_meta_data") +} + +// UpdateAppMetaData updates all app data from a map of updates +func (u *User) UpdateAppMetaData(tx *storage.Connection, updates map[string]interface{}) error { + if u.AppMetaData == nil { + u.AppMetaData = updates + } else { + for key, value := range updates { + if value != nil { + u.AppMetaData[key] = value + } else { + delete(u.AppMetaData, key) + } + } + } + return tx.UpdateOnly(u, "raw_app_meta_data") +} + +// UpdateAppMetaDataProviders updates the provider field in AppMetaData column +func (u *User) UpdateAppMetaDataProviders(tx *storage.Connection) error { + providers, terr := FindProvidersByUser(tx, u) + if terr != nil { + return terr + } + payload := map[string]interface{}{ + "providers": providers, + } + if len(providers) > 0 { + payload["provider"] = providers[0] + } + return u.UpdateAppMetaData(tx, payload) +} + +// UpdateUserEmail updates the user's email to one of the identity's email +// if the current email used doesn't match any of the identities email +func (u *User) UpdateUserEmailFromIdentities(tx *storage.Connection) error { + identities, terr := FindIdentitiesByUserID(tx, u.ID) + if terr != nil { + return terr + } + for _, i := range identities { + if u.GetEmail() == i.GetEmail() { + // there's an existing identity that uses the same email + // so the user's email can be kept + return nil + } + } + + var primaryIdentity *Identity + for _, i := range identities { + if _, terr := FindUserByEmailAndAudience(tx, i.GetEmail(), u.Aud); terr != nil { + if IsNotFoundError(terr) { + // the identity's email is not used by another user + // so we can set it as the primary identity + primaryIdentity = i + break + } + return terr + } + } + if primaryIdentity == nil { + return UserEmailUniqueConflictError{} + } + // default to the first identity's email + if terr := u.SetEmail(tx, primaryIdentity.GetEmail()); terr != nil { + return terr + } + if primaryIdentity.GetEmail() == "" { + u.EmailConfirmedAt = nil + if terr := tx.UpdateOnly(u, "email_confirmed_at"); terr != nil { + return terr + } + } + return nil +} + +// SetEmail sets the user's email +func (u *User) SetEmail(tx *storage.Connection, email string) error { + u.Email = storage.NullString(email) + return tx.UpdateOnly(u, "email") +} + +// SetPhone sets the user's phone +func (u *User) SetPhone(tx *storage.Connection, phone string) error { + u.Phone = storage.NullString(phone) + return tx.UpdateOnly(u, "phone") +} + +func (u *User) SetPassword(ctx context.Context, password string, encrypt bool, encryptionKeyID, encryptionKey string) error { + if password == "" { + u.EncryptedPassword = nil + return nil + } + + pw, err := crypto.GenerateFromPassword(ctx, password) + if err != nil { + return err + } + + u.EncryptedPassword = &pw + if encrypt { + es, err := crypto.NewEncryptedString(u.ID.String(), []byte(pw), encryptionKeyID, encryptionKey) + if err != nil { + return err + } + + encryptedPassword := es.String() + u.EncryptedPassword = &encryptedPassword + } + + return nil +} + +// UpdatePassword updates the user's password. Use SetPassword outside of a transaction first! 
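+//
+// A sketch of the intended two-step flow (assuming the usual pattern of a
+// *storage.Connection named db with a Transaction helper; sessionID is the
+// caller's current session ID):
+//
+//	if err := user.SetPassword(ctx, newPassword, false, "", ""); err != nil {
+//		return err
+//	}
+//	err := db.Transaction(func(tx *storage.Connection) error {
+//		return user.UpdatePassword(tx, &sessionID)
+//	})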
+func (u *User) UpdatePassword(tx *storage.Connection, sessionID *uuid.UUID) error { + // These need to be reset because password change may mean the user no longer trusts the actions performed by the previous password. + u.ConfirmationToken = "" + u.ConfirmationSentAt = nil + u.RecoveryToken = "" + u.RecoverySentAt = nil + u.EmailChangeTokenCurrent = "" + u.EmailChangeTokenNew = "" + u.EmailChangeSentAt = nil + u.PhoneChangeToken = "" + u.PhoneChangeSentAt = nil + u.ReauthenticationToken = "" + u.ReauthenticationSentAt = nil + + if err := tx.UpdateOnly(u, "encrypted_password", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_current", "email_change_token_new", "email_change_sent_at", "phone_change_token", "phone_change_sent_at", "reauthentication_token", "reauthentication_sent_at"); err != nil { + return err + } + + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + + if sessionID == nil { + // log out user from all sessions to ensure reauthentication after password change + return Logout(tx, u.ID) + } else { + // log out user from all other sessions to ensure reauthentication after password change + return LogoutAllExceptMe(tx, *sessionID, u.ID) + } +} + +// Authenticate a user from a password +func (u *User) Authenticate(ctx context.Context, tx *storage.Connection, password string, decryptionKeys map[string]string, encrypt bool, encryptionKeyID string) (bool, bool, error) { + if u.EncryptedPassword == nil { + return false, false, nil + } + + hash := *u.EncryptedPassword + + if hash == "" { + return false, false, nil + } + + es := crypto.ParseEncryptedString(hash) + if es != nil { + h, err := es.Decrypt(u.ID.String(), decryptionKeys) + if err != nil { + return false, false, err + } + + hash = string(h) + } + + compareErr := crypto.CompareHashAndPassword(ctx, hash, password) + + if !strings.HasPrefix(hash, crypto.Argon2Prefix) && !strings.HasPrefix(hash, crypto.FirebaseScryptPrefix) { + // check if cost exceeds default cost or is too low + cost, err := bcrypt.Cost([]byte(hash)) + if err != nil { + return compareErr == nil, false, err + } + + if cost > bcrypt.DefaultCost || cost == bcrypt.MinCost { + // don't bother with encrypting the password in Authenticate + // since it's handled separately + if err := u.SetPassword(ctx, password, false, "", ""); err != nil { + return compareErr == nil, false, err + } + } + } + + return compareErr == nil, encrypt && (es == nil || es.ShouldReEncrypt(encryptionKeyID)), nil +} + +// ConfirmReauthentication resets the reauthentication token +func (u *User) ConfirmReauthentication(tx *storage.Connection) error { + u.ReauthenticationToken = "" + if err := tx.UpdateOnly(u, "reauthentication_token"); err != nil { + return err + } + + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + + return nil +} + +// Confirm resets the confimation token and sets the confirm timestamp +func (u *User) Confirm(tx *storage.Connection) error { + u.ConfirmationToken = "" + now := time.Now() + u.EmailConfirmedAt = &now + + if err := tx.UpdateOnly(u, "confirmation_token", "email_confirmed_at"); err != nil { + return err + } + + if err := u.UpdateUserMetaData(tx, map[string]interface{}{ + "email_verified": true, + }); err != nil { + return err + } + + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + + return nil +} + +// ConfirmPhone resets the confimation token and sets the confirm timestamp +func (u *User) ConfirmPhone(tx 
*storage.Connection) error { + u.ConfirmationToken = "" + now := time.Now() + u.PhoneConfirmedAt = &now + if err := tx.UpdateOnly(u, "confirmation_token", "phone_confirmed_at"); err != nil { + return nil + } + + return ClearAllOneTimeTokensForUser(tx, u.ID) +} + +// UpdateLastSignInAt update field last_sign_in_at for user according to specified field +func (u *User) UpdateLastSignInAt(tx *storage.Connection) error { + return tx.UpdateOnly(u, "last_sign_in_at") +} + +// ConfirmEmailChange confirm the change of email for a user +func (u *User) ConfirmEmailChange(tx *storage.Connection, status int) error { + email := u.EmailChange + + u.Email = storage.NullString(email) + u.EmailChange = "" + u.EmailChangeTokenCurrent = "" + u.EmailChangeTokenNew = "" + u.EmailChangeConfirmStatus = status + + if err := tx.UpdateOnly( + u, + "email", + "email_change", + "email_change_token_current", + "email_change_token_new", + "email_change_confirm_status", + ); err != nil { + return err + } + + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + + if !u.IsConfirmed() { + if err := u.Confirm(tx); err != nil { + return err + } + } + + identity, err := FindIdentityByIdAndProvider(tx, u.ID.String(), "email") + if err != nil { + if IsNotFoundError(err) { + // no email identity, not an error + return nil + } + return err + } + + if _, ok := identity.IdentityData["email"]; ok { + identity.IdentityData["email"] = email + if err := tx.UpdateOnly(identity, "identity_data"); err != nil { + return err + } + } + + return nil +} + +// ConfirmPhoneChange confirms the change of phone for a user +func (u *User) ConfirmPhoneChange(tx *storage.Connection) error { + now := time.Now() + phone := u.PhoneChange + + u.Phone = storage.NullString(phone) + u.PhoneChange = "" + u.PhoneChangeToken = "" + u.PhoneConfirmedAt = &now + + if err := tx.UpdateOnly( + u, + "phone", + "phone_change", + "phone_change_token", + "phone_confirmed_at", + ); err != nil { + return err + } + + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + + identity, err := FindIdentityByIdAndProvider(tx, u.ID.String(), "phone") + if err != nil { + if IsNotFoundError(err) { + // no phone identity, not an error + return nil + } + + return err + } + + if _, ok := identity.IdentityData["phone"]; ok { + identity.IdentityData["phone"] = phone + } + + if err := tx.UpdateOnly(identity, "identity_data"); err != nil { + return err + } + + return nil +} + +// Recover resets the recovery token +func (u *User) Recover(tx *storage.Connection) error { + u.RecoveryToken = "" + if err := tx.UpdateOnly(u, "recovery_token"); err != nil { + return err + } + + return ClearAllOneTimeTokensForUser(tx, u.ID) +} + +// CountOtherUsers counts how many other users exist besides the one provided +func CountOtherUsers(tx *storage.Connection, id uuid.UUID) (int, error) { + userCount, err := tx.Q().Where("instance_id = ? and id != ?", uuid.Nil, id).Count(&User{}) + return userCount, errors.Wrap(err, "error finding registered users") +} + +func findUser(tx *storage.Connection, query string, args ...interface{}) (*User, error) { + obj := &User{} + if err := tx.Eager().Q().Where(query, args...).First(obj); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, UserNotFoundError{} + } + return nil, errors.Wrap(err, "error finding user") + } + + return obj, nil +} + +// FindUserByEmailAndAudience finds a user with the matching email and audience. 
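+//
+// Callers are expected to handle the not-found case explicitly (sketch; tx is
+// an open *storage.Connection, the audience is illustrative):
+//
+//	user, err := FindUserByEmailAndAudience(tx, "someone@example.com", "authenticated")
+//	if IsNotFoundError(err) {
+//		// no such user for this audience
+//	} else if err != nil {
+//		return err
+//	}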
+func FindUserByEmailAndAudience(tx *storage.Connection, email, aud string) (*User, error) { + return findUser(tx, "instance_id = ? and LOWER(email) = ? and aud = ? and is_sso_user = false", uuid.Nil, strings.ToLower(email), aud) +} + +// FindUserByPhoneAndAudience finds a user with the matching email and audience. +func FindUserByPhoneAndAudience(tx *storage.Connection, phone, aud string) (*User, error) { + return findUser(tx, "instance_id = ? and phone = ? and aud = ? and is_sso_user = false", uuid.Nil, phone, aud) +} + +// FindUserByID finds a user matching the provided ID. +func FindUserByID(tx *storage.Connection, id uuid.UUID) (*User, error) { + return findUser(tx, "instance_id = ? and id = ?", uuid.Nil, id) +} + +// FindUserWithRefreshToken finds a user from the provided refresh token. If +// forUpdate is set to true, then the SELECT statement used by the query has +// the form SELECT ... FOR UPDATE SKIP LOCKED. This means that a FOR UPDATE +// lock will only be acquired if there's no other lock. In case there is a +// lock, a IsNotFound(err) error will be returned. +func FindUserWithRefreshToken(tx *storage.Connection, token string, forUpdate bool) (*User, *RefreshToken, *Session, error) { + refreshToken := &RefreshToken{} + + if forUpdate { + // pop does not provide us with a way to execute FOR UPDATE + // queries which lock the rows affected by the query from + // being accessed by any other transaction that also uses FOR + // UPDATE + if err := tx.RawQuery(fmt.Sprintf("SELECT * FROM %q WHERE token = ? LIMIT 1 FOR UPDATE SKIP LOCKED;", refreshToken.TableName()), token).First(refreshToken); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil, nil, RefreshTokenNotFoundError{} + } + + return nil, nil, nil, errors.Wrap(err, "error finding refresh token for update") + } + } + + // once the rows are locked (if forUpdate was true), we can query again using pop + if err := tx.Where("token = ?", token).First(refreshToken); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil, nil, RefreshTokenNotFoundError{} + } + return nil, nil, nil, errors.Wrap(err, "error finding refresh token") + } + + user, err := FindUserByID(tx, refreshToken.UserID) + if err != nil { + return nil, nil, nil, err + } + + var session *Session + + if refreshToken.SessionId != nil { + sessionId := *refreshToken.SessionId + + if sessionId != uuid.Nil { + session, err = FindSessionByID(tx, sessionId, forUpdate) + if err != nil { + if forUpdate { + return nil, nil, nil, err + } + + if !IsNotFoundError(err) { + return nil, nil, nil, errors.Wrap(err, "error finding session from refresh token") + } + + // otherwise, there's no session for this refresh token + } + } + } + + return user, refreshToken, session, nil +} + +// FindUsersInAudience finds users with the matching audience. +func FindUsersInAudience(tx *storage.Connection, aud string, pageParams *Pagination, sortParams *SortParams, filter string) ([]*User, error) { + users := []*User{} + q := tx.Q().Where("instance_id = ? and aud = ?", uuid.Nil, aud) + + if filter != "" { + lf := "%" + filter + "%" + // we must specify the collation in order to get case insensitive search for the JSON column + q = q.Where("(email LIKE ? 
OR raw_user_meta_data->>'full_name' ILIKE ?)", lf, lf) + } + + if sortParams != nil && len(sortParams.Fields) > 0 { + for _, field := range sortParams.Fields { + q = q.Order(field.Name + " " + string(field.Dir)) + } + } + + var err error + if pageParams != nil { + err = q.Paginate(int(pageParams.Page), int(pageParams.PerPage)).All(&users) // #nosec G115 + pageParams.Count = uint64(q.Paginator.TotalEntriesSize) // #nosec G115 + } else { + err = q.All(&users) + } + + return users, err +} + +// IsDuplicatedEmail returns whether a user exists with a matching email and audience. +// If a currentUser is provided, we will need to filter out any identities that belong to the current user. +func IsDuplicatedEmail(tx *storage.Connection, email, aud string, currentUser *User) (*User, error) { + var identities []Identity + + if err := tx.Eager().Q().Where("email = ?", strings.ToLower(email)).All(&identities); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil + } + + return nil, errors.Wrap(err, "unable to find identity by email for duplicates") + } + + userIDs := make(map[string]uuid.UUID) + for _, identity := range identities { + if _, ok := userIDs[identity.UserID.String()]; !ok { + if !identity.IsForSSOProvider() { + userIDs[identity.UserID.String()] = identity.UserID + } + } + } + + var currentUserId uuid.UUID + if currentUser != nil { + currentUserId = currentUser.ID + } + + for _, userID := range userIDs { + if userID != currentUserId { + user, err := FindUserByID(tx, userID) + if err != nil { + return nil, errors.Wrap(err, "unable to find user from email identity for duplicates") + } + if user.Aud == aud { + return user, nil + } + } + } + + // out of an abundance of caution, if nothing was found via the + // identities table we also do a final check on the users table + user, err := FindUserByEmailAndAudience(tx, email, aud) + if err != nil && !IsNotFoundError(err) { + return nil, errors.Wrap(err, "unable to find user email address for duplicates") + } + + return user, nil +} + +// IsDuplicatedPhone checks if the phone number already exists in the users table +func IsDuplicatedPhone(tx *storage.Connection, phone, aud string) (bool, error) { + _, err := FindUserByPhoneAndAudience(tx, phone, aud) + if err != nil { + if IsNotFoundError(err) { + return false, nil + } + return false, err + } + return true, nil +} + +// Ban a user for a given duration. 
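+//
+// A zero duration lifts an existing ban, any other duration sets banned_until
+// relative to time.Now() (sketch; tx is an open *storage.Connection):
+//
+//	if err := u.Ban(tx, 24*time.Hour); err != nil { // ban for one day
+//		return err
+//	}
+//	if err := u.Ban(tx, 0); err != nil { // lift the ban again
+//		return err
+//	}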
+func (u *User) Ban(tx *storage.Connection, duration time.Duration) error { + if duration == time.Duration(0) { + u.BannedUntil = nil + } else { + t := time.Now().Add(duration) + u.BannedUntil = &t + } + return tx.UpdateOnly(u, "banned_until") +} + +// IsBanned checks if a user is banned or not +func (u *User) IsBanned() bool { + if u.BannedUntil == nil { + return false + } + return time.Now().Before(*u.BannedUntil) +} + +func (u *User) HasMFAEnabled() bool { + for _, factor := range u.Factors { + if factor.IsVerified() { + return true + } + } + + return false +} + +func (u *User) UpdateBannedUntil(tx *storage.Connection) error { + return tx.UpdateOnly(u, "banned_until") +} + +// RemoveUnconfirmedIdentities removes potentially malicious unconfirmed identities from a user (if any) +func (u *User) RemoveUnconfirmedIdentities(tx *storage.Connection, identity *Identity) error { + if identity.Provider != "email" && identity.Provider != "phone" { + // user is unconfirmed so the password should be reset + u.EncryptedPassword = nil + if terr := tx.UpdateOnly(u, "encrypted_password"); terr != nil { + return terr + } + } + + // user is unconfirmed so existing user_metadata should be overwritten + // to use the current identity metadata + u.UserMetaData = identity.IdentityData + if terr := u.UpdateUserMetaData(tx, u.UserMetaData); terr != nil { + return terr + } + + // finally, remove all identities except the current identity being authenticated + for i := range u.Identities { + if u.Identities[i].ID != identity.ID { + if terr := tx.Destroy(&u.Identities[i]); terr != nil { + return terr + } + } + } + + // user is unconfirmed so none of the providers associated to it are verified yet + // only the current provider should be kept + if terr := u.UpdateAppMetaDataProviders(tx); terr != nil { + return terr + } + return nil +} + +// SoftDeleteUser performs a soft deletion on the user by obfuscating and clearing certain fields +func (u *User) SoftDeleteUser(tx *storage.Connection) error { + u.Email = storage.NullString(obfuscateEmail(u, u.GetEmail())) + u.Phone = storage.NullString(obfuscatePhone(u, u.GetPhone())) + u.EmailChange = obfuscateEmail(u, u.EmailChange) + u.PhoneChange = obfuscatePhone(u, u.PhoneChange) + u.EncryptedPassword = nil + u.ConfirmationToken = "" + u.RecoveryToken = "" + u.EmailChangeTokenCurrent = "" + u.EmailChangeTokenNew = "" + u.PhoneChangeToken = "" + + // set deleted_at time + now := time.Now() + u.DeletedAt = &now + + if err := tx.UpdateOnly( + u, + "email", + "phone", + "encrypted_password", + "email_change", + "phone_change", + "confirmation_token", + "recovery_token", + "email_change_token_current", + "email_change_token_new", + "phone_change_token", + "deleted_at", + ); err != nil { + return err + } + + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + + // set raw_user_meta_data to {} + userMetaDataUpdates := map[string]interface{}{} + for k := range u.UserMetaData { + userMetaDataUpdates[k] = nil + } + + if err := u.UpdateUserMetaData(tx, userMetaDataUpdates); err != nil { + return err + } + + // set raw_app_meta_data to {} + appMetaDataUpdates := map[string]interface{}{} + for k := range u.AppMetaData { + appMetaDataUpdates[k] = nil + } + + if err := u.UpdateAppMetaData(tx, appMetaDataUpdates); err != nil { + return err + } + + if err := Logout(tx, u.ID); err != nil { + return err + } + + return nil +} + +// SoftDeleteUserIdentities performs a soft deletion on all identities associated to a user +func (u *User) SoftDeleteUserIdentities(tx 
*storage.Connection) error { + identities, err := FindIdentitiesByUserID(tx, u.ID) + if err != nil { + return err + } + + // set identity_data to {} + for _, identity := range identities { + identityDataUpdates := map[string]interface{}{} + for k := range identity.IdentityData { + identityDataUpdates[k] = nil + } + if err := identity.UpdateIdentityData(tx, identityDataUpdates); err != nil { + return err + } + // updating the identity.ID has to happen last since the primary key is on (provider, id) + // we use RawQuery here instead of UpdateOnly because UpdateOnly relies on the primary key of Identity + if err := tx.RawQuery( + "update "+ + (&pop.Model{Value: Identity{}}).TableName()+ + " set provider_id = ? where id = ?", + obfuscateIdentityProviderId(identity), + identity.ID, + ).Exec(); err != nil { + return err + } + } + return nil +} + +func (u *User) FindOwnedFactorByID(tx *storage.Connection, factorID uuid.UUID) (*Factor, error) { + var factor Factor + err := tx.Where("user_id = ? AND id = ?", u.ID, factorID).First(&factor) + if err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, &FactorNotFoundError{} + } + return nil, err + } + return &factor, nil +} + +func (user *User) WebAuthnID() []byte { + return []byte(user.ID.String()) +} + +func (user *User) WebAuthnName() string { + return user.Email.String() +} + +func (user *User) WebAuthnDisplayName() string { + return user.Email.String() +} + +func (user *User) WebAuthnCredentials() []webauthn.Credential { + var credentials []webauthn.Credential + + for _, factor := range user.Factors { + if factor.IsVerified() && factor.FactorType == WebAuthn { + credential := factor.WebAuthnCredential.Credential + credentials = append(credentials, credential) + } + } + + return credentials +} + +func obfuscateValue(id uuid.UUID, value string) string { + hash := sha256.Sum256([]byte(id.String() + value)) + return base64.RawURLEncoding.EncodeToString(hash[:]) +} + +func obfuscateEmail(u *User, email string) string { + return obfuscateValue(u.ID, email) +} + +func obfuscatePhone(u *User, phone string) string { + // Field converted from VARCHAR(15) to text + return obfuscateValue(u.ID, phone)[:15] +} + +func obfuscateIdentityProviderId(identity *Identity) string { + return obfuscateValue(identity.UserID, identity.Provider+":"+identity.ProviderID) +} + +// FindUserByPhoneChangeAndAudience finds a user with the matching phone change and audience. +func FindUserByPhoneChangeAndAudience(tx *storage.Connection, phone, aud string) (*User, error) { + return findUser(tx, "instance_id = ? and phone_change = ? and aud = ? 
and is_sso_user = false", uuid.Nil, phone, aud) +} diff --git a/auth_v2.169.0/internal/models/user_test.go b/auth_v2.169.0/internal/models/user_test.go new file mode 100644 index 0000000..0349543 --- /dev/null +++ b/auth_v2.169.0/internal/models/user_test.go @@ -0,0 +1,467 @@ +package models + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" + "golang.org/x/crypto/bcrypt" +) + +const modelsTestConfig = "../../hack/test.env" + +func init() { + crypto.PasswordHashCost = crypto.QuickHashCost +} + +type UserTestSuite struct { + suite.Suite + db *storage.Connection +} + +func (ts *UserTestSuite) SetupTest() { + TruncateAll(ts.db) +} + +func TestUser(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + ts := &UserTestSuite{ + db: conn, + } + defer ts.db.Close() + + suite.Run(t, ts) +} + +func (ts *UserTestSuite) TestUpdateAppMetadata() { + u, err := NewUser("", "", "", "", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), u.UpdateAppMetaData(ts.db, make(map[string]interface{}))) + + require.NotNil(ts.T(), u.AppMetaData) + + require.NoError(ts.T(), u.UpdateAppMetaData(ts.db, map[string]interface{}{ + "foo": "bar", + })) + + require.Equal(ts.T(), "bar", u.AppMetaData["foo"]) + require.NoError(ts.T(), u.UpdateAppMetaData(ts.db, map[string]interface{}{ + "foo": nil, + })) + require.Len(ts.T(), u.AppMetaData, 0) + require.Equal(ts.T(), nil, u.AppMetaData["foo"]) +} + +func (ts *UserTestSuite) TestUpdateUserMetadata() { + u, err := NewUser("", "", "", "", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), u.UpdateUserMetaData(ts.db, make(map[string]interface{}))) + + require.NotNil(ts.T(), u.UserMetaData) + + require.NoError(ts.T(), u.UpdateUserMetaData(ts.db, map[string]interface{}{ + "foo": "bar", + })) + + require.Equal(ts.T(), "bar", u.UserMetaData["foo"]) + require.NoError(ts.T(), u.UpdateUserMetaData(ts.db, map[string]interface{}{ + "foo": nil, + })) + require.Len(ts.T(), u.UserMetaData, 0) + require.Equal(ts.T(), nil, u.UserMetaData["foo"]) +} + +func (ts *UserTestSuite) TestFindUserByConfirmationToken() { + u := ts.createUser() + tokenHash := "test_confirmation_token" + require.NoError(ts.T(), CreateOneTimeToken(ts.db, u.ID, "relates_to not used", tokenHash, ConfirmationToken)) + + n, err := FindUserByConfirmationToken(ts.db, tokenHash) + require.NoError(ts.T(), err) + require.Equal(ts.T(), u.ID, n.ID) +} + +func (ts *UserTestSuite) TestFindUserByEmailAndAudience() { + u := ts.createUser() + + n, err := FindUserByEmailAndAudience(ts.db, u.GetEmail(), "test") + require.NoError(ts.T(), err) + require.Equal(ts.T(), u.ID, n.ID) + + _, err = FindUserByEmailAndAudience(ts.db, u.GetEmail(), "invalid") + require.EqualError(ts.T(), err, UserNotFoundError{}.Error()) +} + +func (ts *UserTestSuite) TestFindUsersInAudience() { + u := ts.createUser() + + n, err := FindUsersInAudience(ts.db, u.Aud, nil, nil, "") + require.NoError(ts.T(), err) + require.Len(ts.T(), n, 1) + + p := Pagination{ + Page: 1, + PerPage: 50, + } + n, err = FindUsersInAudience(ts.db, u.Aud, &p, nil, "") + require.NoError(ts.T(), err) + require.Len(ts.T(), n, 1) + assert.Equal(ts.T(), uint64(1), 
p.Count) + + sp := &SortParams{ + Fields: []SortField{ + {Name: "created_at", Dir: Descending}, + }, + } + n, err = FindUsersInAudience(ts.db, u.Aud, nil, sp, "") + require.NoError(ts.T(), err) + require.Len(ts.T(), n, 1) +} + +func (ts *UserTestSuite) TestFindUserByID() { + u := ts.createUser() + + n, err := FindUserByID(ts.db, u.ID) + require.NoError(ts.T(), err) + require.Equal(ts.T(), u.ID, n.ID) +} + +func (ts *UserTestSuite) TestFindUserByRecoveryToken() { + u := ts.createUser() + tokenHash := "test_recovery_token" + require.NoError(ts.T(), CreateOneTimeToken(ts.db, u.ID, "relates_to not used", tokenHash, RecoveryToken)) + + n, err := FindUserByRecoveryToken(ts.db, tokenHash) + require.NoError(ts.T(), err) + require.Equal(ts.T(), u.ID, n.ID) +} + +func (ts *UserTestSuite) TestFindUserWithRefreshToken() { + u := ts.createUser() + r, err := GrantAuthenticatedUser(ts.db, u, GrantParams{}) + require.NoError(ts.T(), err) + + n, nr, s, err := FindUserWithRefreshToken(ts.db, r.Token, true /* forUpdate */) + require.NoError(ts.T(), err) + require.Equal(ts.T(), r.ID, nr.ID) + require.Equal(ts.T(), u.ID, n.ID) + require.NotNil(ts.T(), s) + require.Equal(ts.T(), *r.SessionId, s.ID) +} + +func (ts *UserTestSuite) TestIsDuplicatedEmail() { + _ = ts.createUserWithEmail("david.calavera@netlify.com") + + e, err := IsDuplicatedEmail(ts.db, "david.calavera@netlify.com", "test", nil) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), e, "expected email to be duplicated") + + e, err = IsDuplicatedEmail(ts.db, "davidcalavera@netlify.com", "test", nil) + require.NoError(ts.T(), err) + require.Nil(ts.T(), e, "expected email to not be duplicated", nil) + + e, err = IsDuplicatedEmail(ts.db, "david@netlify.com", "test", nil) + require.NoError(ts.T(), err) + require.Nil(ts.T(), e, "expected same email to not be duplicated", nil) + + e, err = IsDuplicatedEmail(ts.db, "david.calavera@netlify.com", "other-aud", nil) + require.NoError(ts.T(), err) + require.Nil(ts.T(), e, "expected same email to not be duplicated") +} + +func (ts *UserTestSuite) createUser() *User { + return ts.createUserWithEmail("david@netlify.com") +} + +func (ts *UserTestSuite) createUserWithEmail(email string) *User { + user, err := NewUser("", email, "secret", "test", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(user)) + + identity, err := NewIdentity(user, "email", map[string]interface{}{ + "sub": user.ID.String(), + "email": email, + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identity)) + + return user +} + +func (ts *UserTestSuite) TestRemoveUnconfirmedIdentities() { + user, err := NewUser("+29382983298", "someone@example.com", "abcdefgh", "authenticated", nil) + require.NoError(ts.T(), err) + + user.AppMetaData = map[string]interface{}{ + "provider": "email", + "providers": []string{"email", "phone", "twitter"}, + } + + require.NoError(ts.T(), ts.db.Create(user)) + + idEmail, err := NewIdentity(user, "email", map[string]interface{}{ + "sub": "someone@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(idEmail)) + + idPhone, err := NewIdentity(user, "phone", map[string]interface{}{ + "sub": "+29382983298", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(idPhone)) + + idTwitter, err := NewIdentity(user, "twitter", map[string]interface{}{ + "sub": "test_twitter_user_id", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(idTwitter)) + + user.Identities = append(user.Identities, *idEmail, 
*idPhone, *idTwitter) + + // reload the user + require.NoError(ts.T(), ts.db.Load(user)) + + require.False(ts.T(), user.IsConfirmed(), "user's email must not be confirmed") + + require.NoError(ts.T(), user.RemoveUnconfirmedIdentities(ts.db, idTwitter)) + + // reload the user to check that identities are deleted from the db too + require.NoError(ts.T(), ts.db.Load(user)) + require.Empty(ts.T(), user.EncryptedPassword, "password still remains in user") + + require.Len(ts.T(), user.Identities, 1, "only one identity must be remaining") + require.Equal(ts.T(), idTwitter.ID, user.Identities[0].ID, "remaining identity is not the expected one") + + require.NotNil(ts.T(), user.AppMetaData) + require.Equal(ts.T(), user.AppMetaData["provider"], "twitter") + require.Equal(ts.T(), user.AppMetaData["providers"], []string{"twitter"}) +} + +func (ts *UserTestSuite) TestConfirmEmailChange() { + user, err := NewUser("", "test@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(user)) + + identity, err := NewIdentity(user, "email", map[string]interface{}{ + "sub": user.ID.String(), + "email": "test@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identity)) + + user.EmailChange = "new@example.com" + require.NoError(ts.T(), ts.db.UpdateOnly(user, "email_change")) + + require.NoError(ts.T(), user.ConfirmEmailChange(ts.db, 0)) + + require.NoError(ts.T(), ts.db.Eager().Load(user)) + identity, err = FindIdentityByIdAndProvider(ts.db, user.ID.String(), "email") + require.NoError(ts.T(), err) + + require.Equal(ts.T(), user.Email, storage.NullString("new@example.com")) + require.Equal(ts.T(), user.EmailChange, "") + + require.NotNil(ts.T(), identity.IdentityData) + require.Equal(ts.T(), identity.IdentityData["email"], "new@example.com") +} + +func (ts *UserTestSuite) TestConfirmPhoneChange() { + user, err := NewUser("123456789", "", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(user)) + + identity, err := NewIdentity(user, "phone", map[string]interface{}{ + "sub": user.ID.String(), + "phone": "123456789", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(identity)) + + user.PhoneChange = "987654321" + require.NoError(ts.T(), ts.db.UpdateOnly(user, "phone_change")) + + require.NoError(ts.T(), user.ConfirmPhoneChange(ts.db)) + + require.NoError(ts.T(), ts.db.Eager().Load(user)) + identity, err = FindIdentityByIdAndProvider(ts.db, user.ID.String(), "phone") + require.NoError(ts.T(), err) + + require.Equal(ts.T(), user.Phone, storage.NullString("987654321")) + require.Equal(ts.T(), user.PhoneChange, "") + + require.NotNil(ts.T(), identity.IdentityData) + require.Equal(ts.T(), identity.IdentityData["phone"], "987654321") +} + +func (ts *UserTestSuite) TestUpdateUserEmailSuccess() { + userA, err := NewUser("", "foo@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userA)) + + primaryIdentity, err := NewIdentity(userA, "email", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "foo@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(primaryIdentity)) + + secondaryIdentity, err := NewIdentity(userA, "google", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "bar@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(secondaryIdentity)) + + // UpdateUserEmail should not do anything and the user's 
email should still use the primaryIdentity + require.NoError(ts.T(), userA.UpdateUserEmailFromIdentities(ts.db)) + require.Equal(ts.T(), primaryIdentity.GetEmail(), userA.GetEmail()) + + // remove primary identity + require.NoError(ts.T(), ts.db.Destroy(primaryIdentity)) + + // UpdateUserEmail should update the user to use the secondary identity's email + require.NoError(ts.T(), userA.UpdateUserEmailFromIdentities(ts.db)) + require.Equal(ts.T(), secondaryIdentity.GetEmail(), userA.GetEmail()) +} + +func (ts *UserTestSuite) TestUpdateUserEmailFailure() { + userA, err := NewUser("", "foo@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userA)) + + primaryIdentity, err := NewIdentity(userA, "email", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "foo@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(primaryIdentity)) + + secondaryIdentity, err := NewIdentity(userA, "google", map[string]interface{}{ + "sub": userA.ID.String(), + "email": "bar@example.com", + }) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(secondaryIdentity)) + + userB, err := NewUser("", "bar@example.com", "", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(userB)) + + // remove primary identity + require.NoError(ts.T(), ts.db.Destroy(primaryIdentity)) + + // UpdateUserEmail should fail with the email unique constraint violation error + // since userB is using the secondary identity's email + require.ErrorIs(ts.T(), userA.UpdateUserEmailFromIdentities(ts.db), UserEmailUniqueConflictError{}) + require.Equal(ts.T(), primaryIdentity.GetEmail(), userA.GetEmail()) +} + +func (ts *UserTestSuite) TestNewUserWithPasswordHashSuccess() { + cases := []struct { + desc string + hash string + }{ + { + desc: "Valid bcrypt hash", + hash: "$2y$10$SXEz2HeT8PUIGQXo9yeUIem8KzNxgG0d7o/.eGj2rj8KbRgAuRVlq", + }, + { + desc: "Valid argon2i hash", + hash: "$argon2i$v=19$m=16,t=2,p=1$bGJRWThNOHJJTVBSdHl2dQ$NfEnUOuUpb7F2fQkgFUG4g", + }, + { + desc: "Valid argon2id hash", + hash: "$argon2id$v=19$m=32,t=3,p=2$SFVpOWJ0eXhjRzVkdGN1RQ$RXnb8rh7LaDcn07xsssqqulZYXOM/EUCEFMVcAcyYVk", + }, + { + desc: "Valid Firebase scrypt hash", + hash: "$fbscrypt$v=1,n=14,r=8,p=1,ss=Bw==,sk=ou9tdYTGyYm8kuR6Dt0Bp0kDuAYoXrK16mbZO4yGwAn3oLspjnN0/c41v8xZnO1n14J3MjKj1b2g6AUCAlFwMw==$C0sHCg9ek77hsg==$ZGlmZmVyZW50aGFzaA==", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + u, err := NewUserWithPasswordHash("", "", c.hash, "", nil) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), u) + }) + } +} + +func (ts *UserTestSuite) TestNewUserWithPasswordHashFailure() { + cases := []struct { + desc string + hash string + }{ + { + desc: "Invalid argon2i hash", + hash: "$argon2id$test", + }, + { + desc: "Invalid bcrypt hash", + hash: "plaintest_password", + }, + { + desc: "Invalid scrypt hash", + hash: "$fbscrypt$invalid", + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + u, err := NewUserWithPasswordHash("", "", c.hash, "", nil) + require.Error(ts.T(), err) + require.Nil(ts.T(), u) + }) + } +} + +func (ts *UserTestSuite) TestAuthenticate() { + // every case uses "test" as the password + cases := []struct { + desc string + hash string + expectedHashCost int + }{ + { + desc: "Invalid bcrypt hash cost of 11", + hash: "$2y$11$4lH57PU7bGATpRcx93vIoObH3qDmft/pytbOzDG9/1WsyNmN5u4di", + expectedHashCost: bcrypt.MinCost, + }, + { + desc: "Valid bcrypt hash cost of 10", + hash: 
"$2y$10$va66S4MxFrH6G6L7BzYl0.QgcYgvSr/F92gc.3botlz7bG4p/g/1i", + expectedHashCost: bcrypt.DefaultCost, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + u, err := NewUserWithPasswordHash("", "", c.hash, "", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.db.Create(u)) + require.NotNil(ts.T(), u) + + isAuthenticated, _, err := u.Authenticate(context.Background(), ts.db, "test", nil, false, "") + require.NoError(ts.T(), err) + require.True(ts.T(), isAuthenticated) + + // check hash cost + hashCost, err := bcrypt.Cost([]byte(*u.EncryptedPassword)) + require.NoError(ts.T(), err) + require.Equal(ts.T(), c.expectedHashCost, hashCost) + }) + } +} diff --git a/auth_v2.169.0/internal/observability/cleanup.go b/auth_v2.169.0/internal/observability/cleanup.go new file mode 100644 index 0000000..2e88c35 --- /dev/null +++ b/auth_v2.169.0/internal/observability/cleanup.go @@ -0,0 +1,18 @@ +package observability + +import ( + "context" + "sync" + + "github.com/supabase/auth/internal/utilities" +) + +var ( + cleanupWaitGroup sync.WaitGroup +) + +// WaitForCleanup waits until all observability long-running goroutines shut +// down cleanly or until the provided context signals done. +func WaitForCleanup(ctx context.Context) { + utilities.WaitForCleanup(ctx, &cleanupWaitGroup) +} diff --git a/auth_v2.169.0/internal/observability/logging.go b/auth_v2.169.0/internal/observability/logging.go new file mode 100644 index 0000000..ff8ac96 --- /dev/null +++ b/auth_v2.169.0/internal/observability/logging.go @@ -0,0 +1,125 @@ +package observability + +import ( + "os" + "sync" + "time" + + "github.com/bombsimon/logrusr/v3" + "github.com/gobuffalo/pop/v6" + "github.com/gobuffalo/pop/v6/logging" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "go.opentelemetry.io/otel" +) + +const ( + LOG_SQL_ALL = "all" + LOG_SQL_NONE = "none" + LOG_SQL_STATEMENT = "statement" +) + +var ( + loggingOnce sync.Once +) + +type CustomFormatter struct { + logrus.JSONFormatter +} + +func NewCustomFormatter() *CustomFormatter { + return &CustomFormatter{ + JSONFormatter: logrus.JSONFormatter{ + DisableTimestamp: false, + TimestampFormat: time.RFC3339, + }, + } +} + +func (f *CustomFormatter) Format(entry *logrus.Entry) ([]byte, error) { + // logrus doesn't support formatting the time in UTC so we need to use a custom formatter + entry.Time = entry.Time.UTC() + return f.JSONFormatter.Format(entry) +} + +func ConfigureLogging(config *conf.LoggingConfig) error { + var err error + + loggingOnce.Do(func() { + formatter := NewCustomFormatter() + logrus.SetFormatter(formatter) + + // use a file if you want + if config.File != "" { + f, errOpen := os.OpenFile(config.File, os.O_RDWR|os.O_APPEND|os.O_CREATE, 0660) //#nosec G302 -- Log files should be rw-rw-r-- + if errOpen != nil { + err = errOpen + return + } + logrus.SetOutput(f) + logrus.Infof("Set output file to %s", config.File) + } + + if config.Level != "" { + level, errParse := logrus.ParseLevel(config.Level) + if err != nil { + err = errParse + return + } + logrus.SetLevel(level) + logrus.Debug("Set log level to: " + logrus.GetLevel().String()) + } + + f := logrus.Fields{} + for k, v := range config.Fields { + f[k] = v + } + logrus.WithFields(f) + + setPopLogger(config.SQL) + + otel.SetLogger(logrusr.New(logrus.StandardLogger().WithField("component", "otel"))) + }) + + return err +} + +func setPopLogger(sql string) { + popLog := logrus.WithField("component", "pop") + sqlLog := logrus.WithField("component", "sql") + + shouldLogSQL := 
sql == LOG_SQL_STATEMENT || sql == LOG_SQL_ALL + shouldLogSQLArgs := sql == LOG_SQL_ALL + + pop.SetLogger(func(lvl logging.Level, s string, args ...interface{}) { + // Special case SQL logging since we have 2 extra flags to check + if lvl == logging.SQL { + if !shouldLogSQL { + return + } + + if shouldLogSQLArgs && len(args) > 0 { + sqlLog.WithField("args", args).Info(s) + } else { + sqlLog.Info(s) + } + return + } + + l := popLog + if len(args) > 0 { + l = l.WithField("args", args) + } + + switch lvl { + case logging.SQL, logging.Debug: + l.Debug(s) + case logging.Info: + l.Info(s) + case logging.Warn: + l.Warn(s) + case logging.Error: + l.Error(s) + } + }) +} diff --git a/auth_v2.169.0/internal/observability/metrics.go b/auth_v2.169.0/internal/observability/metrics.go new file mode 100644 index 0000000..b3632aa --- /dev/null +++ b/auth_v2.169.0/internal/observability/metrics.go @@ -0,0 +1,202 @@ +package observability + +import ( + "context" + "fmt" + "net" + "net/http" + "sync" + "time" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + + "github.com/prometheus/client_golang/prometheus/promhttp" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc" + "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp" + "go.opentelemetry.io/otel/exporters/prometheus" + "go.opentelemetry.io/otel/metric" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + + otelruntimemetrics "go.opentelemetry.io/contrib/instrumentation/runtime" +) + +func Meter(instrumentationName string, opts ...metric.MeterOption) metric.Meter { + return otel.Meter(instrumentationName, opts...) +} + +func ObtainMetricCounter(name, desc string) metric.Int64Counter { + counter, err := Meter("gotrue").Int64Counter(name, metric.WithDescription(desc)) + if err != nil { + panic(err) + } + return counter +} + +func enablePrometheusMetrics(ctx context.Context, mc *conf.MetricsConfig) error { + exporter, err := prometheus.New() + if err != nil { + return err + } + + provider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(exporter)) + + otel.SetMeterProvider(provider) + + cleanupWaitGroup.Add(1) + go func() { + addr := net.JoinHostPort(mc.PrometheusListenHost, mc.PrometheusListenPort) + baseContext, cancel := context.WithCancel(context.Background()) + + server := &http.Server{ + Addr: addr, + Handler: promhttp.Handler(), + BaseContext: func(net.Listener) context.Context { + return baseContext + }, + ReadHeaderTimeout: 2 * time.Second, // to mitigate a Slowloris attack + } + + go func() { + defer cleanupWaitGroup.Done() + <-ctx.Done() + + cancel() // close baseContext + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + if err := server.Shutdown(shutdownCtx); err != nil { + logrus.WithError(err).Errorf("prometheus server (%s) failed to gracefully shut down", addr) + } + }() + + logrus.Infof("prometheus server listening on %s", addr) + + if err := server.ListenAndServe(); err != nil { + logrus.WithError(err).Errorf("prometheus server (%s) shut down", addr) + } else { + logrus.Info("prometheus metric exporter shut down") + } + }() + + return nil +} + +func enableOpenTelemetryMetrics(ctx context.Context, mc *conf.MetricsConfig) error { + switch mc.ExporterProtocol { + case "grpc": + metricExporter, err := otlpmetricgrpc.New(ctx) + if err != nil { + return err + } + meterProvider := sdkmetric.NewMeterProvider( + sdkmetric.WithReader(sdkmetric.NewPeriodicReader(metricExporter)), + ) + + 
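Editorial aside, not part of the committed files: the ObtainMetricCounter helper defined above panics if the instrument cannot be created and otherwise returns a ready-to-use metric.Int64Counter. A minimal sketch of how such a counter might be consumed is shown below; the counter name, attribute key and package are illustrative assumptions, and the gRPC branch of enableOpenTelemetryMetrics continues immediately after this aside.

package example

import (
	"context"

	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/metric"

	"github.com/supabase/auth/internal/observability"
)

// signupCounter and the "provider" attribute are assumed names used only for
// illustration; they are not defined anywhere in this commit.
var signupCounter = observability.ObtainMetricCounter("gotrue_signups", "Number of sign-up attempts")

// recordSignup increments the counter once and tags the data point with the
// provider that handled the sign-up.
func recordSignup(ctx context.Context, provider string) {
	signupCounter.Add(ctx, 1, metric.WithAttributes(attribute.String("provider", provider)))
}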
otel.SetMeterProvider(meterProvider) + + cleanupWaitGroup.Add(1) + go func() { + defer cleanupWaitGroup.Done() + + <-ctx.Done() + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + if err := metricExporter.Shutdown(shutdownCtx); err != nil { + logrus.WithError(err).Error("unable to gracefully shut down OpenTelemetry metric exporter") + } else { + logrus.Info("OpenTelemetry metric exporter shut down") + } + }() + + case "http/protobuf": + metricExporter, err := otlpmetrichttp.New(ctx) + if err != nil { + return err + } + meterProvider := sdkmetric.NewMeterProvider( + sdkmetric.WithReader(sdkmetric.NewPeriodicReader(metricExporter)), + ) + + otel.SetMeterProvider(meterProvider) + + cleanupWaitGroup.Add(1) + go func() { + defer cleanupWaitGroup.Done() + + <-ctx.Done() + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + if err := metricExporter.Shutdown(shutdownCtx); err != nil { + logrus.WithError(err).Error("unable to gracefully shut down OpenTelemetry metric exporter") + } else { + logrus.Info("OpenTelemetry metric exporter shut down") + } + }() + + default: // http/json for example + return fmt.Errorf("unsupported OpenTelemetry exporter protocol %q", mc.ExporterProtocol) + } + logrus.Info("OpenTelemetry metrics exporter started") + return nil + +} + +var ( + metricsOnce *sync.Once = &sync.Once{} +) + +func ConfigureMetrics(ctx context.Context, mc *conf.MetricsConfig) error { + if ctx == nil { + panic("context must not be nil") + } + + var err error + + metricsOnce.Do(func() { + if mc.Enabled { + switch mc.Exporter { + case conf.Prometheus: + if err = enablePrometheusMetrics(ctx, mc); err != nil { + logrus.WithError(err).Error("unable to start prometheus metrics exporter") + return + } + + case conf.OpenTelemetryMetrics: + if err = enableOpenTelemetryMetrics(ctx, mc); err != nil { + logrus.WithError(err).Error("unable to start OTLP metrics exporter") + + return + } + } + } + + if err := otelruntimemetrics.Start(otelruntimemetrics.WithMinimumReadMemStatsInterval(time.Second)); err != nil { + logrus.WithError(err).Error("unable to start OpenTelemetry Go runtime metrics collection") + } else { + logrus.Info("Go runtime metrics collection started") + } + + meter := otel.Meter("gotrue") + _, err := meter.Int64ObservableGauge( + "gotrue_running", + metric.WithDescription("Whether GoTrue is running (always 1)"), + metric.WithInt64Callback(func(_ context.Context, obsrv metric.Int64Observer) error { + obsrv.Observe(int64(1)) + return nil + }), + ) + if err != nil { + logrus.WithError(err).Error("unable to get gotrue.gotrue_running gague metric") + return + } + }) + + return err +} diff --git a/auth_v2.169.0/internal/observability/profiler.go b/auth_v2.169.0/internal/observability/profiler.go new file mode 100644 index 0000000..71acc11 --- /dev/null +++ b/auth_v2.169.0/internal/observability/profiler.go @@ -0,0 +1,87 @@ +package observability + +import ( + "context" + "net" + "time" + + "net/http" + "net/http/pprof" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" +) + +func ConfigureProfiler(ctx context.Context, pc *conf.ProfilerConfig) error { + if !pc.Enabled { + return nil + } + addr := net.JoinHostPort(pc.Host, pc.Port) + baseContext, cancel := context.WithCancel(context.Background()) + cleanupWaitGroup.Add(1) + go func() { + server := &http.Server{ + Addr: addr, + Handler: &ProfilerHandler{}, + BaseContext: func(net.Listener) 
context.Context { + return baseContext + }, + ReadHeaderTimeout: 2 * time.Second, + } + + go func() { + defer cleanupWaitGroup.Done() + <-ctx.Done() + + cancel() // close baseContext + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + if err := server.Shutdown(shutdownCtx); err != nil { + logrus.WithError(err).Errorf("profiler server (%s) failed to gracefully shut down", addr) + } + }() + + logrus.Infof("Profiler is listening on %s", addr) + + if err := server.ListenAndServe(); err != nil { + logrus.WithError(err).Errorf("profiler server (%s) shut down", addr) + } else { + logrus.Info("profiler shut down") + } + }() + + return nil +} + +type ProfilerHandler struct{} + +func (p *ProfilerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/debug/pprof/": + pprof.Index(w, r) + case "/debug/pprof/cmdline": + pprof.Cmdline(w, r) + case "/debug/pprof/profile": + pprof.Profile(w, r) + case "/debug/pprof/symbol": + pprof.Symbol(w, r) + case "/debug/pprof/trace": + pprof.Trace(w, r) + case "/debug/pprof/goroutine": + pprof.Handler("goroutine").ServeHTTP(w, r) + case "/debug/pprof/heap": + pprof.Handler("heap").ServeHTTP(w, r) + case "/debug/pprof/allocs": + pprof.Handler("allocs").ServeHTTP(w, r) + case "/debug/pprof/threadcreate": + pprof.Handler("threadcreate").ServeHTTP(w, r) + case "/debug/pprof/block": + pprof.Handler("block").ServeHTTP(w, r) + case "/debug/pprof/mutex": + pprof.Handler("mutex").ServeHTTP(w, r) + default: + http.NotFound(w, r) + } +} diff --git a/auth_v2.169.0/internal/observability/request-logger.go b/auth_v2.169.0/internal/observability/request-logger.go new file mode 100644 index 0000000..6eeffd6 --- /dev/null +++ b/auth_v2.169.0/internal/observability/request-logger.go @@ -0,0 +1,114 @@ +package observability + +import ( + "fmt" + "net/http" + "time" + + chimiddleware "github.com/go-chi/chi/v5/middleware" + "github.com/gofrs/uuid" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" +) + +func AddRequestID(globalConfig *conf.GlobalConfiguration) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + id := uuid.Must(uuid.NewV4()).String() + if globalConfig.API.RequestIDHeader != "" { + id = r.Header.Get(globalConfig.API.RequestIDHeader) + } + ctx := r.Context() + ctx = utilities.WithRequestID(ctx, id) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) + } +} + +func NewStructuredLogger(logger *logrus.Logger, config *conf.GlobalConfiguration) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/health" { + next.ServeHTTP(w, r) + } else { + chimiddleware.RequestLogger(&structuredLogger{logger, config})(next).ServeHTTP(w, r) + } + }) + } +} + +type structuredLogger struct { + Logger *logrus.Logger + Config *conf.GlobalConfiguration +} + +func (l *structuredLogger) NewLogEntry(r *http.Request) chimiddleware.LogEntry { + referrer := utilities.GetReferrer(r, l.Config) + e := &logEntry{Entry: logrus.NewEntry(l.Logger)} + logFields := logrus.Fields{ + "component": "api", + "method": r.Method, + "path": r.URL.Path, + "remote_addr": utilities.GetIPAddress(r), + "referer": referrer, + } + + if reqID := utilities.GetRequestID(r.Context()); reqID != "" { + 
logFields["request_id"] = reqID + } + + e.Entry = e.Entry.WithFields(logFields) + return e +} + +// logEntry implements the chiMiddleware.LogEntry interface +type logEntry struct { + Entry *logrus.Entry +} + +func (e *logEntry) Write(status, bytes int, header http.Header, elapsed time.Duration, extra interface{}) { + fields := logrus.Fields{ + "status": status, + "duration": elapsed.Nanoseconds(), + } + + errorCode := header.Get("x-sb-error-code") + if errorCode != "" { + fields["error_code"] = errorCode + } + + entry := e.Entry.WithFields(fields) + entry.Info("request completed") + e.Entry = entry +} + +func (e *logEntry) Panic(v interface{}, stack []byte) { + entry := e.Entry.WithFields(logrus.Fields{ + "stack": string(stack), + "panic": fmt.Sprintf("%+v", v), + }) + entry.Error("request panicked") + e.Entry = entry +} + +func GetLogEntry(r *http.Request) *logEntry { + l, _ := chimiddleware.GetLogEntry(r).(*logEntry) + if l == nil { + return &logEntry{Entry: logrus.NewEntry(logrus.StandardLogger())} + } + return l +} + +func LogEntrySetField(r *http.Request, key string, value interface{}) { + if l, ok := r.Context().Value(chimiddleware.LogEntryCtxKey).(*logEntry); ok { + l.Entry = l.Entry.WithField(key, value) + } +} + +func LogEntrySetFields(r *http.Request, fields logrus.Fields) { + if l, ok := r.Context().Value(chimiddleware.LogEntryCtxKey).(*logEntry); ok { + l.Entry = l.Entry.WithFields(fields) + } +} diff --git a/auth_v2.169.0/internal/observability/request-logger_test.go b/auth_v2.169.0/internal/observability/request-logger_test.go new file mode 100644 index 0000000..7ab244c --- /dev/null +++ b/auth_v2.169.0/internal/observability/request-logger_test.go @@ -0,0 +1,72 @@ +package observability + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +const apiTestConfig = "../../hack/test.env" + +func TestLogger(t *testing.T) { + var logBuffer bytes.Buffer + config, err := conf.LoadGlobal(apiTestConfig) + require.NoError(t, err) + + config.Logging.Level = "info" + require.NoError(t, ConfigureLogging(&config.Logging)) + + // logrus should write to the buffer so we can check if the logs are output correctly + logrus.SetOutput(&logBuffer) + + // add request id header + config.API.RequestIDHeader = "X-Request-ID" + addRequestIdHandler := AddRequestID(config) + + logHandler := NewStructuredLogger(logrus.StandardLogger(), config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "http://example.com/path", nil) + req.Header.Add("X-Request-ID", "test-request-id") + require.NoError(t, err) + addRequestIdHandler(logHandler).ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + var logs map[string]interface{} + require.NoError(t, json.NewDecoder(&logBuffer).Decode(&logs)) + require.Equal(t, "api", logs["component"]) + require.Equal(t, http.MethodPost, logs["method"]) + require.Equal(t, "/path", logs["path"]) + require.Equal(t, "test-request-id", logs["request_id"]) + require.NotNil(t, logs["time"]) +} + +func TestExcludeHealthFromLogs(t *testing.T) { + var logBuffer bytes.Buffer + config, err := conf.LoadGlobal(apiTestConfig) + require.NoError(t, err) + + config.Logging.Level = "info" + require.NoError(t, ConfigureLogging(&config.Logging)) + + // logrus should write to the buffer so we can check if the logs 
are output correctly + logrus.SetOutput(&logBuffer) + + logHandler := NewStructuredLogger(logrus.StandardLogger(), config)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + })) + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "http://example.com/health", nil) + require.NoError(t, err) + logHandler.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + + require.Empty(t, logBuffer) +} diff --git a/auth_v2.169.0/internal/observability/request-tracing.go b/auth_v2.169.0/internal/observability/request-tracing.go new file mode 100644 index 0000000..e8ee61b --- /dev/null +++ b/auth_v2.169.0/internal/observability/request-tracing.go @@ -0,0 +1,170 @@ +package observability + +import ( + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/sirupsen/logrus" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + semconv "go.opentelemetry.io/otel/semconv/v1.25.0" + "go.opentelemetry.io/otel/trace" +) + +// traceChiRoutesSafely attempts to extract the Chi RouteContext. If the +// request does not have a RouteContext it will recover from the panic and +// attempt to figure out the route from the URL's path. +func traceChiRoutesSafely(r *http.Request) { + defer func() { + if rec := recover(); rec != nil { + logrus.WithField("error", rec).Error("unable to trace chi routes, traces may be off") + + span := trace.SpanFromContext(r.Context()) + span.SetAttributes(semconv.HTTPRouteKey.String(r.URL.Path)) + } + }() + + routeContext := chi.RouteContext(r.Context()) + span := trace.SpanFromContext(r.Context()) + span.SetAttributes(semconv.HTTPRouteKey.String(routeContext.RoutePattern())) +} + +// traceChiRouteURLParamsSafely attempts to extract the Chi RouteContext +// URLParams values for the route and assign them to the tracing span. If the +// request does not have a RouteContext it will recover from the panic and not +// set any params. +func traceChiRouteURLParamsSafely(r *http.Request) { + defer func() { + if rec := recover(); rec != nil { + logrus.WithField("error", rec).Error("unable to trace route with route params, traces may be off") + } + }() + + routeContext := chi.RouteContext(r.Context()) + span := trace.SpanFromContext(r.Context()) + + var attributes []attribute.KeyValue + + for i := 0; i < len(routeContext.URLParams.Keys); i += 1 { + key := routeContext.URLParams.Keys[i] + value := routeContext.URLParams.Values[i] + + attributes = append(attributes, attribute.String("http.route.param."+key, value)) + } + + if len(attributes) > 0 { + span.SetAttributes(attributes...) + } +} + +type interceptingResponseWriter struct { + writer http.ResponseWriter + + statusCode int +} + +func (w *interceptingResponseWriter) WriteHeader(statusCode int) { + w.statusCode = statusCode + + w.writer.WriteHeader(statusCode) +} + +func (w *interceptingResponseWriter) Write(data []byte) (int, error) { + return w.writer.Write(data) +} + +func (w *interceptingResponseWriter) Header() http.Header { + return w.writer.Header() +} + +// countStatusCodesSafely counts the number of HTTP status codes per route that +// occurred while GoTrue was running. If it is not able to identify the route +// via chi.RouteContext(ctx).RoutePattern() it counts with a noroute attribute. 
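Editorial aside, not part of the committed files: the request logger above attaches a *logrus.Entry to every non-/health request and emits a final "request completed" line with status and duration, and LogEntrySetField is the hook handlers use to enrich that line. A minimal sketch under that assumption follows; the handler, lookupUserID and the user_id key are illustrative, and countStatusCodesSafely continues right after the aside.

package example

import (
	"net/http"

	"github.com/supabase/auth/internal/observability"
)

// lookupUserID stands in for whatever authentication logic a real handler
// would run; it exists only so the sketch is self-contained.
func lookupUserID(r *http.Request) string {
	return r.Header.Get("X-User-ID")
}

// profileHandler attaches a request-scoped field to the structured log entry
// created by NewStructuredLogger, so the final "request completed" line
// carries it alongside status and duration.
func profileHandler(w http.ResponseWriter, r *http.Request) {
	observability.LogEntrySetField(r, "user_id", lookupUserID(r))
	w.WriteHeader(http.StatusOK)
}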
+func countStatusCodesSafely(w *interceptingResponseWriter, r *http.Request, counter metric.Int64Counter) { + if counter == nil { + return + } + + defer func() { + if rec := recover(); rec != nil { + logrus.WithField("error", rec).Error("unable to count status codes safely, metrics may be off") + counter.Add( + r.Context(), + 1, + metric.WithAttributes( + attribute.Bool("noroute", true), + attribute.Int("code", w.statusCode)), + ) + } + }() + + ctx := r.Context() + + routeContext := chi.RouteContext(ctx) + routePattern := semconv.HTTPRouteKey.String(routeContext.RoutePattern()) + + counter.Add( + ctx, + 1, + metric.WithAttributes(attribute.Int("code", w.statusCode), routePattern), + ) +} + +// RequestTracing returns an HTTP handler that traces all HTTP requests coming +// in. Supports Chi routers, so this should be one of the first middlewares on +// the router. +func RequestTracing() func(http.Handler) http.Handler { + meter := otel.Meter("gotrue") + statusCodes, err := meter.Int64Counter( + "http_status_codes", + metric.WithDescription("Number of returned HTTP status codes"), + ) + if err != nil { + logrus.WithError(err).Error("unable to get gotrue.http_status_codes counter metric") + } + + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + writer := interceptingResponseWriter{ + writer: w, + } + + defer traceChiRoutesSafely(r) + defer traceChiRouteURLParamsSafely(r) + defer countStatusCodesSafely(&writer, r, statusCodes) + + originalUserAgent := r.Header.Get("X-Gotrue-Original-User-Agent") + if originalUserAgent != "" { + r.Header.Set("User-Agent", originalUserAgent) + } + + next.ServeHTTP(&writer, r) + + if originalUserAgent != "" { + r.Header.Set("X-Gotrue-Original-User-Agent", originalUserAgent) + r.Header.Set("User-Agent", "stripped") + } + } + + otelHandler := otelhttp.NewHandler(http.HandlerFunc(fn), "api") + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // there is a vulnerability with otelhttp where + // User-Agent strings are kept in RAM indefinitely and + // can be used as an easy way to resource exhaustion; + // so this code strips the User-Agent header before + // it's passed to be traced by otelhttp, and then is + // returned back to the middleware + // https://github.com/supabase/gotrue/security/dependabot/11 + userAgent := r.UserAgent() + if userAgent != "" { + r.Header.Set("X-Gotrue-Original-User-Agent", userAgent) + r.Header.Set("User-Agent", "stripped") + } + + otelHandler.ServeHTTP(w, r) + }) + } +} diff --git a/auth_v2.169.0/internal/observability/tracing.go b/auth_v2.169.0/internal/observability/tracing.go new file mode 100644 index 0000000..cc18471 --- /dev/null +++ b/auth_v2.169.0/internal/observability/tracing.go @@ -0,0 +1,130 @@ +package observability + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/utilities" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" + "go.opentelemetry.io/otel/propagation" + sdkresource "go.opentelemetry.io/otel/sdk/resource" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/trace" +) + +func Tracer(name string, opts ...trace.TracerOption) trace.Tracer { + return otel.Tracer(name, opts...) 
+} + +func openTelemetryResource() *sdkresource.Resource { + environmentResource := sdkresource.Environment() + gotrueResource := sdkresource.NewSchemaless(attribute.String("gotrue.version", utilities.Version)) + + mergedResource, err := sdkresource.Merge(environmentResource, gotrueResource) + if err != nil { + logrus.WithError(err).Error("unable to merge OpenTelemetry environment and gotrue resources") + + return environmentResource + } + + return mergedResource +} + +func enableOpenTelemetryTracing(ctx context.Context, tc *conf.TracingConfig) error { + var ( + err error + traceExporter *otlptrace.Exporter + ) + + switch tc.ExporterProtocol { + case "grpc": + traceExporter, err = otlptracegrpc.New(ctx) + if err != nil { + return err + } + + case "http/protobuf": + traceExporter, err = otlptracehttp.New(ctx) + if err != nil { + return err + } + + default: // http/json for example + return fmt.Errorf("unsupported OpenTelemetry exporter protocol %q", tc.ExporterProtocol) + } + + traceProvider := sdktrace.NewTracerProvider( + sdktrace.WithBatcher(traceExporter), + sdktrace.WithResource(openTelemetryResource()), + ) + + otel.SetTracerProvider(traceProvider) + + // Register the W3C trace context and baggage propagators so data is + // propagated across services/processes + otel.SetTextMapPropagator( + propagation.NewCompositeTextMapPropagator( + propagation.TraceContext{}, + propagation.Baggage{}, + ), + ) + + cleanupWaitGroup.Add(1) + go func() { + defer cleanupWaitGroup.Done() + + <-ctx.Done() + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + + if err := traceExporter.Shutdown(shutdownCtx); err != nil { + logrus.WithError(err).Error("unable to shutdown OpenTelemetry trace exporter") + } + + if err := traceProvider.Shutdown(shutdownCtx); err != nil { + logrus.WithError(err).Error("unable to shutdown OpenTelemetry trace provider") + } + }() + + logrus.Info("OpenTelemetry trace exporter started") + + return nil +} + +var ( + tracingOnce sync.Once +) + +// ConfigureTracing sets up global tracing configuration for OpenTracing / +// OpenTelemetry. The context should be the global context. Cancelling this +// context will cancel tracing collection. +func ConfigureTracing(ctx context.Context, tc *conf.TracingConfig) error { + if ctx == nil { + panic("context must not be nil") + } + + var err error + + tracingOnce.Do(func() { + if tc.Enabled { + if tc.Exporter == conf.OpenTelemetryTracing { + if err = enableOpenTelemetryTracing(ctx, tc); err != nil { + logrus.WithError(err).Error("unable to start OTLP trace exporter") + } + + } + } + }) + + return err +} diff --git a/auth_v2.169.0/internal/ratelimit/burst.go b/auth_v2.169.0/internal/ratelimit/burst.go new file mode 100644 index 0000000..6ae0ef5 --- /dev/null +++ b/auth_v2.169.0/internal/ratelimit/burst.go @@ -0,0 +1,60 @@ +package ratelimit + +import ( + "time" + + "github.com/supabase/auth/internal/conf" + "golang.org/x/time/rate" +) + +const defaultOverTime = time.Hour + +// BurstLimiter wraps the golang.org/x/time/rate package. +type BurstLimiter struct { + rl *rate.Limiter +} + +// NewBurstLimiter returns a rate limiter configured using the given conf.Rate. +// +// The returned Limiter will be configured with a token bucket containing a +// single token, which will fill up at a rate of 1 event per r.OverTime with +// an initial burst amount of r.Events. +// +// For example: +// - 1/10s is 1 events per 10 seconds with burst of 1. 
+// - 1/2s is 1 events per 2 seconds with burst of 1. +// - 10/10s is 1 events per 10 seconds with burst of 10. +// +// If Rate.Events is <= 0, the burst amount will be set to 1. +// +// See Example_newBurstLimiter for a visualization. +func NewBurstLimiter(r conf.Rate) *BurstLimiter { + // The rate limiter deals in events per second. + d := r.OverTime + if d <= 0 { + d = defaultOverTime + } + + e := r.Events + if e <= 0 { + e = 0 + } + + // BurstLimiter will have an initial token bucket of size `e`. It will + // be refilled at a rate of 1 per duration `d` indefinitely. + rl := &BurstLimiter{ + rl: rate.NewLimiter(rate.Every(d), int(e)), + } + return rl +} + +// Allow implements Limiter by calling AllowAt with the current time. +func (l *BurstLimiter) Allow() bool { + return l.AllowAt(time.Now()) +} + +// AllowAt implements Limiter by calling the underlying x/time/rate.Limiter +// with the given time. +func (l *BurstLimiter) AllowAt(at time.Time) bool { + return l.rl.AllowN(at, 1) +} diff --git a/auth_v2.169.0/internal/ratelimit/burst_test.go b/auth_v2.169.0/internal/ratelimit/burst_test.go new file mode 100644 index 0000000..b854e3b --- /dev/null +++ b/auth_v2.169.0/internal/ratelimit/burst_test.go @@ -0,0 +1,214 @@ +package ratelimit + +import ( + "fmt" + "testing" + "time" + + "github.com/supabase/auth/internal/conf" +) + +func Example_newBurstLimiter() { + now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z") + { + cfg := conf.Rate{Events: 10, OverTime: time.Second * 20} + rl := NewBurstLimiter(cfg) + cur := now + for i := 0; i < 20; i++ { + allowed := rl.AllowAt(cur) + fmt.Printf("%-5v @ %v\n", allowed, cur) + cur = cur.Add(time.Second * 5) + } + } + + // Output: + // true @ 2024-09-24 10:00:00 +0000 UTC + // true @ 2024-09-24 10:00:05 +0000 UTC + // true @ 2024-09-24 10:00:10 +0000 UTC + // true @ 2024-09-24 10:00:15 +0000 UTC + // true @ 2024-09-24 10:00:20 +0000 UTC + // true @ 2024-09-24 10:00:25 +0000 UTC + // true @ 2024-09-24 10:00:30 +0000 UTC + // true @ 2024-09-24 10:00:35 +0000 UTC + // true @ 2024-09-24 10:00:40 +0000 UTC + // true @ 2024-09-24 10:00:45 +0000 UTC + // true @ 2024-09-24 10:00:50 +0000 UTC + // true @ 2024-09-24 10:00:55 +0000 UTC + // true @ 2024-09-24 10:01:00 +0000 UTC + // false @ 2024-09-24 10:01:05 +0000 UTC + // false @ 2024-09-24 10:01:10 +0000 UTC + // false @ 2024-09-24 10:01:15 +0000 UTC + // true @ 2024-09-24 10:01:20 +0000 UTC + // false @ 2024-09-24 10:01:25 +0000 UTC + // false @ 2024-09-24 10:01:30 +0000 UTC + // false @ 2024-09-24 10:01:35 +0000 UTC +} + +func TestBurstLimiter(t *testing.T) { + t.Run("Allow", func(t *testing.T) { + for i := 1; i < 10; i++ { + cfg := conf.Rate{Events: float64(i), OverTime: time.Hour} + rl := NewBurstLimiter(cfg) + for y := i; y > 0; y-- { + if exp, got := true, rl.Allow(); exp != got { + t.Fatalf("exp Allow() to be %v; got %v", exp, got) + } + } + if exp, got := false, rl.Allow(); exp != got { + t.Fatalf("exp Allow() to be %v; got %v", exp, got) + } + } + }) + + t.Run("AllowAt", func(t *testing.T) { + now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z") + + type event struct { + ok bool + at time.Time + + // Event should be `ok` at `at` for `i` times + i int + } + + type testCase struct { + cfg conf.Rate + now time.Time + evts []event + } + cases := []testCase{ + { + cfg: conf.Rate{Events: 20, OverTime: time.Second * 20}, + now: now, + evts: []event{ + // initial burst of 20 is permitted + {true, now, 19}, + + // then denied, even at same time + {false, now, 100}, + + // and continue to deny 
until the next generated token + {false, now.Add(time.Second), 100}, + {false, now.Add(time.Second * 19), 100}, + + // allows a single call to allow at 20 seconds + {true, now.Add(time.Second * 20), 0}, + + // then denied + {false, now.Add(time.Second * 20), 100}, + + // and the pattern repeats + {true, now.Add(time.Second * 40), 0}, + {false, now.Add(time.Second * 40), 100}, + {false, now.Add(time.Second * 59), 100}, + + {true, now.Add(time.Second * 60), 0}, + {false, now.Add(time.Second * 60), 100}, + {false, now.Add(time.Second * 79), 100}, + + {true, now.Add(time.Second * 80), 0}, + {false, now.Add(time.Second * 80), 100}, + {false, now.Add(time.Second * 99), 100}, + + // allow tokens to be built up still + {true, now.Add(time.Hour), 19}, + }, + }, + + { + cfg: conf.Rate{Events: 1, OverTime: time.Second * 20}, + now: now, + evts: []event{ + // initial burst of 1 is permitted + {true, now, 0}, + + // then denied, even at same time + {false, now, 100}, + + // and continue to deny until the next generated token + {false, now.Add(time.Second), 100}, + {false, now.Add(time.Second * 19), 100}, + + // allows a single call to allow at 20 seconds + {true, now.Add(time.Second * 20), 0}, + + // then denied + {false, now.Add(time.Second * 20), 100}, + + // and the pattern repeats + {true, now.Add(time.Second * 40), 0}, + {false, now.Add(time.Second * 40), 100}, + {false, now.Add(time.Second * 59), 100}, + + {true, now.Add(time.Second * 60), 0}, + {false, now.Add(time.Second * 60), 100}, + {false, now.Add(time.Second * 79), 100}, + + {true, now.Add(time.Second * 80), 0}, + {false, now.Add(time.Second * 80), 100}, + {false, now.Add(time.Second * 99), 100}, + }, + }, + + // 1 event per second + { + cfg: conf.Rate{Events: 1, OverTime: time.Second}, + now: now, + evts: []event{ + {true, now, 0}, + {true, now.Add(time.Second), 0}, + {false, now.Add(time.Second), 0}, + {true, now.Add(time.Second * 2), 0}, + }, + }, + + // 1 events per second and OverTime = 1 event per hour. + { + cfg: conf.Rate{Events: 1, OverTime: 0}, + now: now, + evts: []event{ + {true, now, 0}, + {false, now.Add(time.Hour - time.Second), 0}, + {true, now.Add(time.Hour), 0}, + {true, now.Add(time.Hour * 2), 0}, + }, + }, + + // zero value for Events = 0 event per second + { + cfg: conf.Rate{Events: 0, OverTime: time.Second}, + now: now, + evts: []event{ + {false, now, 0}, + {false, now.Add(-time.Second), 0}, + {false, now.Add(time.Second), 0}, + {false, now.Add(time.Second * 2), 0}, + }, + }, + + // zero value for both Events and OverTime = 1 event per hour. + { + cfg: conf.Rate{Events: 0, OverTime: 0}, + now: now, + evts: []event{ + {false, now, 0}, + {false, now.Add(time.Hour - time.Second), 0}, + {false, now.Add(-time.Hour), 0}, + {false, now.Add(time.Hour), 0}, + {false, now.Add(time.Hour * 2), 0}, + }, + }, + } + + for _, tc := range cases { + rl := NewBurstLimiter(tc.cfg) + for _, evt := range tc.evts { + for i := 0; i <= evt.i; i++ { + if exp, got := evt.ok, rl.AllowAt(evt.at); exp != got { + t.Fatalf("exp AllowAt(%v) to be %v; got %v", evt.at, exp, got) + } + } + } + } + }) +} diff --git a/auth_v2.169.0/internal/ratelimit/interval.go b/auth_v2.169.0/internal/ratelimit/interval.go new file mode 100644 index 0000000..a72302f --- /dev/null +++ b/auth_v2.169.0/internal/ratelimit/interval.go @@ -0,0 +1,63 @@ +package ratelimit + +import ( + "sync" + "time" + + "github.com/supabase/auth/internal/conf" +) + +// IntervalLimiter will limit the number of calls to Allow per interval. 
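Editorial aside, not part of the committed files: NewBurstLimiter above builds a token bucket whose initial size is Rate.Events and which refills one token per Rate.OverTime, the behavior that Example_newBurstLimiter visualizes. A small sketch of that behavior with an assumed 3-per-minute rate is shown below; the IntervalLimiter type introduced by the comment above follows immediately after the aside.

package example

import (
	"fmt"
	"time"

	"github.com/supabase/auth/internal/conf"
	"github.com/supabase/auth/internal/ratelimit"
)

func burstSketch() {
	// "3 events over one minute": an initial bucket of 3 tokens, refilled at
	// a rate of one token per minute.
	rl := ratelimit.NewBurstLimiter(conf.Rate{Events: 3, OverTime: time.Minute})

	for i := 0; i < 4; i++ {
		fmt.Println(rl.Allow()) // true, true, true, false
	}

	// Roughly one minute later a single token has been regenerated.
	fmt.Println(rl.AllowAt(time.Now().Add(time.Minute))) // true
}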
+type IntervalLimiter struct { + mu sync.Mutex + ival time.Duration // Count is reset and time updated every ival. + limit int // Limit calls to Allow() per ival. + + // Guarded by mu. + last time.Time // When the limiter was last reset. + count int // Total calls to Allow() since time. +} + +// NewIntervalLimiter returns a rate limiter using the given conf.Rate. +func NewIntervalLimiter(r conf.Rate) *IntervalLimiter { + return &IntervalLimiter{ + ival: r.OverTime, + limit: int(r.Events), + last: time.Now(), + } +} + +// Allow implements Limiter by calling AllowAt with the current time. +func (rl *IntervalLimiter) Allow() bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + return rl.allowAt(time.Now()) +} + +// AllowAt implements Limiter by checking if the current number of permitted +// events within this interval would permit 1 additional event at the current +// time. +// +// When called with a time outside the current active interval the counter is +// reset, meaning it can be vulnerable at the edge of it's intervals so avoid +// small intervals. +func (rl *IntervalLimiter) AllowAt(at time.Time) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + return rl.allowAt(at) +} + +func (rl *IntervalLimiter) allowAt(at time.Time) bool { + since := at.Sub(rl.last) + if ivals := int64(since / rl.ival); ivals > 0 { + rl.last = rl.last.Add(time.Duration(ivals) * rl.ival) + rl.count = 0 + } + if rl.count < rl.limit { + rl.count++ + return true + } + return false +} diff --git a/auth_v2.169.0/internal/ratelimit/interval_test.go b/auth_v2.169.0/internal/ratelimit/interval_test.go new file mode 100644 index 0000000..835ee82 --- /dev/null +++ b/auth_v2.169.0/internal/ratelimit/interval_test.go @@ -0,0 +1,81 @@ +package ratelimit + +import ( + "fmt" + "testing" + "time" + + "github.com/supabase/auth/internal/conf" +) + +func Example_newIntervalLimiter() { + now, _ := time.Parse(time.RFC3339, "2024-09-24T10:00:00.00Z") + cfg := conf.Rate{Events: 100, OverTime: time.Hour * 24} + rl := NewIntervalLimiter(cfg) + rl.last = now + + cur := now + allowed := 0 + + for days := 0; days < 2; days++ { + // First 100 events succeed. + for i := 0; i < 100; i++ { + allow := rl.allowAt(cur) + cur = cur.Add(time.Second) + + if !allow { + fmt.Printf("false @ %v after %v events... [FAILED]\n", cur, allowed) + return + } + allowed++ + } + fmt.Printf("true @ %v for last %v events...\n", cur, allowed) + + // We try hourly until it allows us to make requests again. + denied := 0 + for i := 0; i < 23; i++ { + cur = cur.Add(time.Hour) + allow := rl.AllowAt(cur) + if allow { + fmt.Printf("true @ %v before quota reset... [FAILED]\n", cur) + return + } + denied++ + } + fmt.Printf("false @ %v for last %v events...\n", cur, denied) + + cur = cur.Add(time.Hour) + } + + // Output: + // true @ 2024-09-24 10:01:40 +0000 UTC for last 100 events... + // false @ 2024-09-25 09:01:40 +0000 UTC for last 23 events... + // true @ 2024-09-25 10:03:20 +0000 UTC for last 200 events... + // false @ 2024-09-26 09:03:20 +0000 UTC for last 23 events... +} + +func TestNewIntervalLimiter(t *testing.T) { + t.Run("Allow", func(t *testing.T) { + for i := 1; i < 10; i++ { + cfg := conf.Rate{Events: float64(i), OverTime: time.Hour} + rl := NewIntervalLimiter(cfg) + for y := i; y > 0; y-- { + if exp, got := true, rl.Allow(); exp != got { + t.Fatalf("exp Allow() to be %v; got %v", exp, got) + } + } + if exp, got := false, rl.Allow(); exp != got { + t.Fatalf("exp Allow() to be %v; got %v", exp, got) + } + } + + // should accept a negative burst. 
+ cfg := conf.Rate{Events: 10, OverTime: time.Hour} + rl := NewBurstLimiter(cfg) + for y := 0; y < 10; y++ { + if exp, got := true, rl.Allow(); exp != got { + t.Fatalf("exp Allow() to be %v; got %v", exp, got) + } + } + }) +} diff --git a/auth_v2.169.0/internal/ratelimit/ratelimit.go b/auth_v2.169.0/internal/ratelimit/ratelimit.go new file mode 100644 index 0000000..35fbf9b --- /dev/null +++ b/auth_v2.169.0/internal/ratelimit/ratelimit.go @@ -0,0 +1,34 @@ +package ratelimit + +import ( + "time" + + "github.com/supabase/auth/internal/conf" +) + +// Limiter is the interface implemented by rate limiters. +// +// Implementations of Limiter must be safe for concurrent use. +type Limiter interface { + + // Allow should return true if an event should be allowed at the time + // which it was called, or false otherwise. + Allow() bool + + // AllowAt should return true if an event should be allowed at the given + // time, or false otherwise. + AllowAt(at time.Time) bool +} + +// New returns a new Limiter based on the given config. +// +// When the type is conf.BurstRateType it returns a BurstLimiter, otherwise +// New returns an IntervalLimiter. +func New(r conf.Rate) Limiter { + switch r.GetRateType() { + case conf.BurstRateType: + return NewBurstLimiter(r) + default: + return NewIntervalLimiter(r) + } +} diff --git a/auth_v2.169.0/internal/ratelimit/ratelimit_test.go b/auth_v2.169.0/internal/ratelimit/ratelimit_test.go new file mode 100644 index 0000000..3bac1dc --- /dev/null +++ b/auth_v2.169.0/internal/ratelimit/ratelimit_test.go @@ -0,0 +1,50 @@ +package ratelimit + +import ( + "testing" + + "github.com/supabase/auth/internal/conf" +) + +func TestNew(t *testing.T) { + + // IntervalLimiter + { + var r conf.Rate + err := r.Decode("100") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + + rl := New(r) + if _, ok := rl.(*IntervalLimiter); !ok { + t.Fatalf("exp type *IntervalLimiter; got %T", rl) + } + } + { + var r conf.Rate + err := r.Decode("100.123") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + + rl := New(r) + if _, ok := rl.(*IntervalLimiter); !ok { + t.Fatalf("exp type *IntervalLimiter; got %T", rl) + } + } + + // BurstLimiter + { + var r conf.Rate + err := r.Decode("20/200s") + if err != nil { + t.Fatalf("exp nil err; got %v", err) + } + + rl := New(r) + if _, ok := rl.(*BurstLimiter); !ok { + t.Fatalf("exp type *BurstLimiter; got %T", rl) + } + } +} diff --git a/auth_v2.169.0/internal/reloader/handler.go b/auth_v2.169.0/internal/reloader/handler.go new file mode 100644 index 0000000..bdd15ca --- /dev/null +++ b/auth_v2.169.0/internal/reloader/handler.go @@ -0,0 +1,42 @@ +package reloader + +import ( + "net/http" + "sync/atomic" +) + +// AtomicHandler provides an atomic http.Handler implementation, allowing safe +// handler replacement at runtime. AtomicHandler must be initialized with a call +// to NewAtomicHandler. It will never panic and is safe for concurrent use. +type AtomicHandler struct { + val atomic.Value +} + +// atomicHandlerValue is the value stored within an atomicHandler. +type atomicHandlerValue struct{ http.Handler } + +// NewAtomicHandler creates a new AtomicHandler ready for use. +func NewAtomicHandler(h http.Handler) *AtomicHandler { + ah := new(AtomicHandler) + ah.Store(h) + return ah +} + +// String implements fmt.Stringer by returning a string literal. +func (ah *AtomicHandler) String() string { return "reloader.AtomicHandler" } + +// Store will update this http.Handler to serve future requests using h. 
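Editorial aside, not part of the committed files: ratelimit.New dispatches on the decoded conf.Rate, so a plain count such as "100" yields an IntervalLimiter while a burst form such as "20/200s" yields a BurstLimiter, exactly as TestNew above exercises. A minimal sketch of turning a rate string into a request gate follows; the gate middleware and the 429 response are assumptions, and AtomicHandler.Store continues right after the aside.

package example

import (
	"net/http"

	"github.com/supabase/auth/internal/conf"
	"github.com/supabase/auth/internal/ratelimit"
)

// newLimiterFromString decodes a rate string the same way TestNew does:
// "100" produces an IntervalLimiter, "20/200s" produces a BurstLimiter.
func newLimiterFromString(s string) (ratelimit.Limiter, error) {
	var r conf.Rate
	if err := r.Decode(s); err != nil {
		return nil, err
	}
	return ratelimit.New(r), nil
}

// gate is an assumed middleware sketch: it rejects requests once the limiter
// denies them.
func gate(lim ratelimit.Limiter, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !lim.Allow() {
			http.Error(w, "too many requests", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}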
+func (ah *AtomicHandler) Store(h http.Handler) { + ah.val.Store(&atomicHandlerValue{h}) +} + +// load will return the underlying http.Handler used to serve requests. +func (ah *AtomicHandler) load() http.Handler { + return ah.val.Load().(*atomicHandlerValue).Handler +} + +// ServeHTTP implements the standard libraries http.Handler interface by +// atomically passing the request along to the most recently stored handler. +func (ah *AtomicHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ah.load().ServeHTTP(w, r) +} diff --git a/auth_v2.169.0/internal/reloader/handler_race_test.go b/auth_v2.169.0/internal/reloader/handler_race_test.go new file mode 100644 index 0000000..4d7b5e0 --- /dev/null +++ b/auth_v2.169.0/internal/reloader/handler_race_test.go @@ -0,0 +1,64 @@ +//go:build race +// +build race + +package reloader + +import ( + "context" + "net/http" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestAtomicHandlerRaces(t *testing.T) { + type testHandler struct{ http.Handler } + + hrFn := func() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + } + + const count = 8 + hrFuncMap := make(map[http.Handler]struct{}, count) + for i := 0; i < count; i++ { + hrFuncMap[&testHandler{hrFn()}] = struct{}{} + } + + hr := NewAtomicHandler(nil) + assert.NotNil(t, hr) + + var wg sync.WaitGroup + defer wg.Wait() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second/4) + defer cancel() + + // We create 8 goroutines reading & writing to the handler concurrently. If + // a race condition occurs the test will fail and halt. + for i := 0; i < count; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + for hrFunc := range hrFuncMap { + select { + case <-ctx.Done(): + default: + } + + hr.Store(hrFunc) + + got := hr.load() + _, ok := hrFuncMap[got] + if !ok { + // This will trigger a race failure / exit test + t.Fatal("unknown handler returned from load()") + return + } + } + }() + } + wg.Wait() +} diff --git a/auth_v2.169.0/internal/reloader/handler_test.go b/auth_v2.169.0/internal/reloader/handler_test.go new file mode 100644 index 0000000..182c526 --- /dev/null +++ b/auth_v2.169.0/internal/reloader/handler_test.go @@ -0,0 +1,46 @@ +package reloader + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAtomicHandler(t *testing.T) { + // for ptr identity + type testHandler struct{ http.Handler } + + hrFn := func() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + } + + hrFunc1 := &testHandler{hrFn()} + hrFunc2 := &testHandler{hrFn()} + assert.NotEqual(t, hrFunc1, hrFunc2) + + // a new AtomicHandler should be non-nil + hr := NewAtomicHandler(nil) + assert.NotNil(t, hr) + + // should have no stored handler + { + hrCur := hr.load() + assert.Nil(t, hrCur) + assert.Equal(t, true, hrCur == nil) + } + + // should be non-nil after store + for i := 0; i < 3; i++ { + hr.Store(hrFunc1) + assert.NotNil(t, hr.load()) + assert.Equal(t, hr.load(), hrFunc1) + assert.Equal(t, hr.load() == hrFunc1, true) + + // should update to hrFunc2 + hr.Store(hrFunc2) + assert.NotNil(t, hr.load()) + assert.Equal(t, hr.load(), hrFunc2) + assert.Equal(t, hr.load() == hrFunc2, true) + } +} diff --git a/auth_v2.169.0/internal/reloader/reloader.go b/auth_v2.169.0/internal/reloader/reloader.go new file mode 100644 index 0000000..2b2b55e --- /dev/null +++ b/auth_v2.169.0/internal/reloader/reloader.go @@ -0,0 +1,141 @@ +// Package reloader provides support for 
live configuration reloading.
+package reloader
+
+import (
+	"context"
+	"log"
+	"strings"
+	"time"
+
+	"github.com/fsnotify/fsnotify"
+	"github.com/sirupsen/logrus"
+	"github.com/supabase/auth/internal/conf"
+)
+
+const (
+	// reloadInterval is the interval between configuration reloading. At most
+	// one configuration change may be made between this duration.
+	reloadInterval = time.Second * 10
+
+	// tickerInterval is the maximum latency between configuration reloads.
+	tickerInterval = reloadInterval / 10
+)
+
+type ConfigFunc func(*conf.GlobalConfiguration)
+
+type Reloader struct {
+	watchDir   string
+	reloadIval time.Duration
+	tickerIval time.Duration
+}
+
+func NewReloader(watchDir string) *Reloader {
+	return &Reloader{
+		watchDir:   watchDir,
+		reloadIval: reloadInterval,
+		tickerIval: tickerInterval,
+	}
+}
+
+// reload attempts to create a new *conf.GlobalConfiguration after loading the
+// currently configured watchDir.
+func (rl *Reloader) reload() (*conf.GlobalConfiguration, error) {
+	if err := conf.LoadDirectory(rl.watchDir); err != nil {
+		return nil, err
+	}
+
+	cfg, err := conf.LoadGlobalFromEnv()
+	if err != nil {
+		return nil, err
+	}
+	return cfg, nil
+}
+
+// reloadCheckAt checks whether reload should be called, returning true if the
+// config should be reloaded or false otherwise.
+func (rl *Reloader) reloadCheckAt(at, lastUpdate time.Time) bool {
+	if lastUpdate.IsZero() {
+		return false // no pending updates
+	}
+	if at.Sub(lastUpdate) < rl.reloadIval {
+		return false // waiting for reload interval
+	}
+
+	// Update is pending.
+	return true
+}
+
+func (rl *Reloader) Watch(ctx context.Context, fn ConfigFunc) error {
+	wr, err := fsnotify.NewWatcher()
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer wr.Close()
+
+	tr := time.NewTicker(rl.tickerIval)
+	defer tr.Stop()
+
+	// Ignore errors; if the watch dir doesn't exist we can add it later.
+	if err := wr.Add(rl.watchDir); err != nil {
+		logrus.WithError(err).Error("watch dir failed")
+	}
+
+	var lastUpdate time.Time
+	for {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+
+		case <-tr.C:
+			// This is a simple way to solve watch dir being added later or
+			// being moved and then recreated. I've tested all of these basic
+			// scenarios and wr.WatchList() does not grow which aligns with
+			// the documented behavior.
+			if err := wr.Add(rl.watchDir); err != nil {
+				logrus.WithError(err).Error("watch dir failed")
+			}
+
+			// Check to see if the config is ready to be reloaded.
+			if !rl.reloadCheckAt(time.Now(), lastUpdate) {
+				continue
+			}
+
+			// Reset the last update time before we try to reload the config.
+			lastUpdate = time.Time{}
+
+			cfg, err := rl.reload()
+			if err != nil {
+				logrus.WithError(err).Error("config reload failed")
+				continue
+			}
+
+			// Call the callback function with the latest cfg.
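Editorial aside, not part of the committed files: Watch debounces filesystem events and, at most once per reloadInterval, rebuilds the configuration and hands it to the supplied ConfigFunc, which is a natural place to swap a rebuilt handler into the AtomicHandler defined earlier in this package. A minimal sketch under that assumption is shown below; buildAPI and the listen address are illustrative, and the Watch loop, which invokes fn(cfg) at exactly this point, continues right after the aside.

package example

import (
	"context"
	"net/http"

	"github.com/sirupsen/logrus"
	"github.com/supabase/auth/internal/conf"
	"github.com/supabase/auth/internal/reloader"
)

// buildAPI is a placeholder for whatever constructs the real HTTP handler
// from a configuration; it is assumed purely so the sketch is self-contained.
func buildAPI(cfg *conf.GlobalConfiguration) http.Handler {
	return http.NewServeMux()
}

// serveWithReload serves through an AtomicHandler and lets Watch swap in a
// freshly built handler whenever the *.env files under watchDir change.
func serveWithReload(ctx context.Context, watchDir string, initial *conf.GlobalConfiguration) error {
	ah := reloader.NewAtomicHandler(buildAPI(initial))

	rl := reloader.NewReloader(watchDir)
	go func() {
		if err := rl.Watch(ctx, func(latest *conf.GlobalConfiguration) {
			ah.Store(buildAPI(latest))
		}); err != nil {
			logrus.WithError(err).Error("config reloader exited")
		}
	}()

	// The listen address is illustrative, not taken from this commit.
	return http.ListenAndServe(":9999", ah)
}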
+ fn(cfg) + + case evt, ok := <-wr.Events: + if !ok { + logrus.WithError(err).Error("fsnotify has exited") + return nil + } + + // We only read files ending in .env + if !strings.HasSuffix(evt.Name, ".env") { + continue + } + + switch { + case evt.Op.Has(fsnotify.Create), + evt.Op.Has(fsnotify.Remove), + evt.Op.Has(fsnotify.Rename), + evt.Op.Has(fsnotify.Write): + lastUpdate = time.Now() + } + case err, ok := <-wr.Errors: + if !ok { + logrus.Error("fsnotify has exited") + return nil + } + logrus.WithError(err).Error("fsnotify has reported an error") + } + } +} diff --git a/auth_v2.169.0/internal/reloader/reloader_test.go b/auth_v2.169.0/internal/reloader/reloader_test.go new file mode 100644 index 0000000..ec8e04b --- /dev/null +++ b/auth_v2.169.0/internal/reloader/reloader_test.go @@ -0,0 +1,173 @@ +package reloader + +import ( + "bytes" + "log" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestReloadConfig(t *testing.T) { + dir, cleanup := helpTestDir(t) + defer cleanup() + + rl := NewReloader(dir) + + // Copy the full and valid example configuration. + helpCopyEnvFile(t, dir, "01_example.env", "testdata/50_example.env") + { + cfg, err := rl.reload() + if err != nil { + t.Fatal(err) + } + assert.NotNil(t, cfg) + assert.Equal(t, cfg.External.Apple.Enabled, false) + } + + helpWriteEnvFile(t, dir, "02_example.env", map[string]string{ + "GOTRUE_EXTERNAL_APPLE_ENABLED": "true", + }) + { + cfg, err := rl.reload() + if err != nil { + t.Fatal(err) + } + assert.NotNil(t, cfg) + assert.Equal(t, cfg.External.Apple.Enabled, true) + } + + helpWriteEnvFile(t, dir, "03_example.env.bak", map[string]string{ + "GOTRUE_EXTERNAL_APPLE_ENABLED": "false", + }) + { + cfg, err := rl.reload() + if err != nil { + t.Fatal(err) + } + assert.NotNil(t, cfg) + assert.Equal(t, cfg.External.Apple.Enabled, true) + } +} + +func TestReloadCheckAt(t *testing.T) { + const s10 = time.Second * 10 + + now := time.Now() + tests := []struct { + rl *Reloader + at, lastUpdate time.Time + exp bool + }{ + // no lastUpdate is set (time.IsZero()) + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + exp: false, + }, + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + exp: false, + }, + + // last update within reload interval + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now.Add(-s10 + 1), + exp: false, + }, + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now, + exp: false, + }, + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now.Add(s10 - 1), + exp: false, + }, + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now.Add(s10), + exp: false, + }, + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now.Add(s10 + 1), + exp: false, + }, + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now.Add(s10 * 2), + exp: false, + }, + + // last update was outside our reload interval + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now.Add(-s10), + exp: true, + }, + { + rl: &Reloader{reloadIval: s10, tickerIval: s10 / 10}, + at: now, + lastUpdate: now.Add(-s10 - 1), + exp: true, + }, + } + for _, tc := range tests { + rl := tc.rl + assert.NotNil(t, rl) + assert.Equal(t, rl.reloadCheckAt(tc.at, tc.lastUpdate), tc.exp) + } +} + +func helpTestDir(t testing.TB) (dir string, cleanup func()) { + dir = filepath.Join("testdata", 
t.Name()) + err := os.MkdirAll(dir, 0750) + if err != nil && !os.IsExist(err) { + t.Fatal(err) + } + return dir, func() { os.RemoveAll(dir) } +} + +func helpCopyEnvFile(t testing.TB, dir, name, src string) string { + data, err := os.ReadFile(src) // #nosec G304 + if err != nil { + log.Fatal(err) + } + + dst := filepath.Join(dir, name) + err = os.WriteFile(dst, data, 0600) + if err != nil { + t.Fatal(err) + } + return dst +} + +func helpWriteEnvFile(t testing.TB, dir, name string, values map[string]string) string { + var buf bytes.Buffer + for k, v := range values { + buf.WriteString(k) + buf.WriteString("=") + buf.WriteString(v) + buf.WriteString("\n") + } + + dst := filepath.Join(dir, name) + err := os.WriteFile(dst, buf.Bytes(), 0600) + if err != nil { + t.Fatal(err) + } + return dst +} diff --git a/auth_v2.169.0/internal/reloader/testdata/50_example.env b/auth_v2.169.0/internal/reloader/testdata/50_example.env new file mode 100644 index 0000000..1002d8b --- /dev/null +++ b/auth_v2.169.0/internal/reloader/testdata/50_example.env @@ -0,0 +1,235 @@ +# General Config +# NOTE: The service_role key is required as an authorization header for /admin endpoints + +GOTRUE_JWT_SECRET="CHANGE-THIS! VERY IMPORTANT!" +GOTRUE_JWT_EXP="3600" +GOTRUE_JWT_AUD="authenticated" +GOTRUE_JWT_DEFAULT_GROUP_NAME="authenticated" +GOTRUE_JWT_ADMIN_ROLES="supabase_admin,service_role" + +# Database & API connection details +GOTRUE_DB_DRIVER="postgres" +DB_NAMESPACE="auth" +DATABASE_URL="postgres://supabase_auth_admin:root@localhost:5432/postgres" +API_EXTERNAL_URL="http://localhost:9999" +GOTRUE_API_HOST="localhost" +PORT="9999" + +# SMTP config (generate credentials for signup to work) +GOTRUE_SMTP_HOST="" +GOTRUE_SMTP_PORT="587" +GOTRUE_SMTP_USER="" +GOTRUE_SMTP_MAX_FREQUENCY="5s" +GOTRUE_SMTP_PASS="" +GOTRUE_SMTP_ADMIN_EMAIL="" +GOTRUE_SMTP_SENDER_NAME="" + +# Mailer config +GOTRUE_MAILER_AUTOCONFIRM="true" +GOTRUE_MAILER_URLPATHS_CONFIRMATION="/verify" +GOTRUE_MAILER_URLPATHS_INVITE="/verify" +GOTRUE_MAILER_URLPATHS_RECOVERY="/verify" +GOTRUE_MAILER_URLPATHS_EMAIL_CHANGE="/verify" +GOTRUE_MAILER_SUBJECTS_CONFIRMATION="Confirm Your Email" +GOTRUE_MAILER_SUBJECTS_RECOVERY="Reset Your Password" +GOTRUE_MAILER_SUBJECTS_MAGIC_LINK="Your Magic Link" +GOTRUE_MAILER_SUBJECTS_EMAIL_CHANGE="Confirm Email Change" +GOTRUE_MAILER_SUBJECTS_INVITE="You have been invited" +GOTRUE_MAILER_SECURE_EMAIL_CHANGE_ENABLED="true" + +# Custom mailer template config +GOTRUE_MAILER_TEMPLATES_INVITE="" +GOTRUE_MAILER_TEMPLATES_CONFIRMATION="" +GOTRUE_MAILER_TEMPLATES_RECOVERY="" +GOTRUE_MAILER_TEMPLATES_MAGIC_LINK="" +GOTRUE_MAILER_TEMPLATES_EMAIL_CHANGE="" + +# Signup config +GOTRUE_DISABLE_SIGNUP="false" +GOTRUE_SITE_URL="http://localhost:3000" +GOTRUE_EXTERNAL_EMAIL_ENABLED="true" +GOTRUE_EXTERNAL_PHONE_ENABLED="true" +GOTRUE_EXTERNAL_IOS_BUNDLE_ID="com.supabase.auth" + +# Whitelist redirect to URLs here, a comma separated list of URIs (e.g. 
"https://foo.example.com,https://*.foo.example.com,https://bar.example.com") +GOTRUE_URI_ALLOW_LIST="http://localhost:3000" + +# Apple OAuth config +GOTRUE_EXTERNAL_APPLE_ENABLED="false" +GOTRUE_EXTERNAL_APPLE_CLIENT_ID="" +GOTRUE_EXTERNAL_APPLE_SECRET="" +GOTRUE_EXTERNAL_APPLE_REDIRECT_URI="http://localhost:9999/callback" + +# Azure OAuth config +GOTRUE_EXTERNAL_AZURE_ENABLED="false" +GOTRUE_EXTERNAL_AZURE_CLIENT_ID="" +GOTRUE_EXTERNAL_AZURE_SECRET="" +GOTRUE_EXTERNAL_AZURE_REDIRECT_URI="https://localhost:9999/callback" + +# Bitbucket OAuth config +GOTRUE_EXTERNAL_BITBUCKET_ENABLED="false" +GOTRUE_EXTERNAL_BITBUCKET_CLIENT_ID="" +GOTRUE_EXTERNAL_BITBUCKET_SECRET="" +GOTRUE_EXTERNAL_BITBUCKET_REDIRECT_URI="http://localhost:9999/callback" + +# Discord OAuth config +GOTRUE_EXTERNAL_DISCORD_ENABLED="false" +GOTRUE_EXTERNAL_DISCORD_CLIENT_ID="" +GOTRUE_EXTERNAL_DISCORD_SECRET="" +GOTRUE_EXTERNAL_DISCORD_REDIRECT_URI="https://localhost:9999/callback" + +# Facebook OAuth config +GOTRUE_EXTERNAL_FACEBOOK_ENABLED="false" +GOTRUE_EXTERNAL_FACEBOOK_CLIENT_ID="" +GOTRUE_EXTERNAL_FACEBOOK_SECRET="" +GOTRUE_EXTERNAL_FACEBOOK_REDIRECT_URI="https://localhost:9999/callback" + +# Figma OAuth config +GOTRUE_EXTERNAL_FIGMA_ENABLED="false" +GOTRUE_EXTERNAL_FIGMA_CLIENT_ID="" +GOTRUE_EXTERNAL_FIGMA_SECRET="" +GOTRUE_EXTERNAL_FIGMA_REDIRECT_URI="https://localhost:9999/callback" + +# Gitlab OAuth config +GOTRUE_EXTERNAL_GITLAB_ENABLED="false" +GOTRUE_EXTERNAL_GITLAB_CLIENT_ID="" +GOTRUE_EXTERNAL_GITLAB_SECRET="" +GOTRUE_EXTERNAL_GITLAB_REDIRECT_URI="http://localhost:9999/callback" + +# Google OAuth config +GOTRUE_EXTERNAL_GOOGLE_ENABLED="false" +GOTRUE_EXTERNAL_GOOGLE_CLIENT_ID="" +GOTRUE_EXTERNAL_GOOGLE_SECRET="" +GOTRUE_EXTERNAL_GOOGLE_REDIRECT_URI="http://localhost:9999/callback" + +# Github OAuth config +GOTRUE_EXTERNAL_GITHUB_ENABLED="false" +GOTRUE_EXTERNAL_GITHUB_CLIENT_ID="" +GOTRUE_EXTERNAL_GITHUB_SECRET="" +GOTRUE_EXTERNAL_GITHUB_REDIRECT_URI="http://localhost:9999/callback" + +# Kakao OAuth config +GOTRUE_EXTERNAL_KAKAO_ENABLED="false" +GOTRUE_EXTERNAL_KAKAO_CLIENT_ID="" +GOTRUE_EXTERNAL_KAKAO_SECRET="" +GOTRUE_EXTERNAL_KAKAO_REDIRECT_URI="http://localhost:9999/callback" + +# Notion OAuth config +GOTRUE_EXTERNAL_NOTION_ENABLED="false" +GOTRUE_EXTERNAL_NOTION_CLIENT_ID="" +GOTRUE_EXTERNAL_NOTION_SECRET="" +GOTRUE_EXTERNAL_NOTION_REDIRECT_URI="https://localhost:9999/callback" + +# Twitter OAuth1 config +GOTRUE_EXTERNAL_TWITTER_ENABLED="false" +GOTRUE_EXTERNAL_TWITTER_CLIENT_ID="" +GOTRUE_EXTERNAL_TWITTER_SECRET="" +GOTRUE_EXTERNAL_TWITTER_REDIRECT_URI="http://localhost:9999/callback" + +# Twitch OAuth config +GOTRUE_EXTERNAL_TWITCH_ENABLED="false" +GOTRUE_EXTERNAL_TWITCH_CLIENT_ID="" +GOTRUE_EXTERNAL_TWITCH_SECRET="" +GOTRUE_EXTERNAL_TWITCH_REDIRECT_URI="http://localhost:9999/callback" + +# Spotify OAuth config +GOTRUE_EXTERNAL_SPOTIFY_ENABLED="false" +GOTRUE_EXTERNAL_SPOTIFY_CLIENT_ID="" +GOTRUE_EXTERNAL_SPOTIFY_SECRET="" +GOTRUE_EXTERNAL_SPOTIFY_REDIRECT_URI="http://localhost:9999/callback" + +# Keycloak OAuth config +GOTRUE_EXTERNAL_KEYCLOAK_ENABLED="false" +GOTRUE_EXTERNAL_KEYCLOAK_CLIENT_ID="" +GOTRUE_EXTERNAL_KEYCLOAK_SECRET="" +GOTRUE_EXTERNAL_KEYCLOAK_REDIRECT_URI="http://localhost:9999/callback" +GOTRUE_EXTERNAL_KEYCLOAK_URL="https://keycloak.example.com/auth/realms/myrealm" + +# Linkedin OAuth config +GOTRUE_EXTERNAL_LINKEDIN_ENABLED="true" +GOTRUE_EXTERNAL_LINKEDIN_CLIENT_ID="" +GOTRUE_EXTERNAL_LINKEDIN_SECRET="" + +# Slack OAuth config +GOTRUE_EXTERNAL_SLACK_ENABLED="false" 
+GOTRUE_EXTERNAL_SLACK_CLIENT_ID="" +GOTRUE_EXTERNAL_SLACK_SECRET="" +GOTRUE_EXTERNAL_SLACK_REDIRECT_URI="http://localhost:9999/callback" + +# WorkOS OAuth config +GOTRUE_EXTERNAL_WORKOS_ENABLED="true" +GOTRUE_EXTERNAL_WORKOS_CLIENT_ID="" +GOTRUE_EXTERNAL_WORKOS_SECRET="" +GOTRUE_EXTERNAL_WORKOS_REDIRECT_URI="http://localhost:9999/callback" + +# Zoom OAuth config +GOTRUE_EXTERNAL_ZOOM_ENABLED="false" +GOTRUE_EXTERNAL_ZOOM_CLIENT_ID="" +GOTRUE_EXTERNAL_ZOOM_SECRET="" +GOTRUE_EXTERNAL_ZOOM_REDIRECT_URI="http://localhost:9999/callback" + +# Anonymous auth config +GOTRUE_EXTERNAL_ANONYMOUS_USERS_ENABLED="false" + +# PKCE Config +GOTRUE_EXTERNAL_FLOW_STATE_EXPIRY_DURATION="300s" + +# Phone provider config +GOTRUE_SMS_AUTOCONFIRM="false" +GOTRUE_SMS_MAX_FREQUENCY="5s" +GOTRUE_SMS_OTP_EXP="6000" +GOTRUE_SMS_OTP_LENGTH="6" +GOTRUE_SMS_PROVIDER="twilio" +GOTRUE_SMS_TWILIO_ACCOUNT_SID="" +GOTRUE_SMS_TWILIO_AUTH_TOKEN="" +GOTRUE_SMS_TWILIO_MESSAGE_SERVICE_SID="" +GOTRUE_SMS_TEMPLATE="This is from supabase. Your code is {{ .Code }} ." +GOTRUE_SMS_MESSAGEBIRD_ACCESS_KEY="" +GOTRUE_SMS_MESSAGEBIRD_ORIGINATOR="" +GOTRUE_SMS_TEXTLOCAL_API_KEY="" +GOTRUE_SMS_TEXTLOCAL_SENDER="" +GOTRUE_SMS_VONAGE_API_KEY="" +GOTRUE_SMS_VONAGE_API_SECRET="" +GOTRUE_SMS_VONAGE_FROM="" + +# Captcha config +GOTRUE_SECURITY_CAPTCHA_ENABLED="false" +GOTRUE_SECURITY_CAPTCHA_PROVIDER="hcaptcha" +GOTRUE_SECURITY_CAPTCHA_SECRET="0x0000000000000000000000000000000000000000" +GOTRUE_SECURITY_CAPTCHA_TIMEOUT="10s" +GOTRUE_SESSION_KEY="" + +# SAML config +GOTRUE_EXTERNAL_SAML_ENABLED="true" +GOTRUE_EXTERNAL_SAML_METADATA_URL="" +GOTRUE_EXTERNAL_SAML_API_BASE="http://localhost:9999" +GOTRUE_EXTERNAL_SAML_NAME="auth0" +GOTRUE_EXTERNAL_SAML_SIGNING_CERT="" +GOTRUE_EXTERNAL_SAML_SIGNING_KEY="" + +# Additional Security config +GOTRUE_LOG_LEVEL="debug" +GOTRUE_SECURITY_REFRESH_TOKEN_ROTATION_ENABLED="false" +GOTRUE_SECURITY_REFRESH_TOKEN_REUSE_INTERVAL="0" +GOTRUE_SECURITY_UPDATE_PASSWORD_REQUIRE_REAUTHENTICATION="false" +GOTRUE_OPERATOR_TOKEN="unused-operator-token" +GOTRUE_RATE_LIMIT_HEADER="X-Forwarded-For" +GOTRUE_RATE_LIMIT_EMAIL_SENT="100" + +GOTRUE_MAX_VERIFIED_FACTORS=10 + +# Auth Hook Configuration +GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_ENABLED=false +GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_URI="" +# Only for HTTPS Hooks +GOTRUE_HOOK_CUSTOM_ACCESS_TOKEN_SECRET="" + +GOTRUE_HOOK_CUSTOM_SMS_PROVIDER_ENABLED=false +GOTRUE_HOOK_CUSTOM_SMS_PROVIDER_URI="" +# Only for HTTPS Hooks +GOTRUE_HOOK_CUSTOM_SMS_PROVIDER_SECRET="" + + +# Test OTP Config +GOTRUE_SMS_TEST_OTP=":, :..." +GOTRUE_SMS_TEST_OTP_VALID_UNTIL="2050-01-01T01:00:00Z" # (e.g. 
2023-09-29T08:14:06Z) diff --git a/auth_v2.169.0/internal/security/captcha.go b/auth_v2.169.0/internal/security/captcha.go new file mode 100644 index 0000000..aeacb63 --- /dev/null +++ b/auth_v2.169.0/internal/security/captcha.go @@ -0,0 +1,101 @@ +package security + +import ( + "encoding/json" + "log" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "time" + + "fmt" + + "github.com/pkg/errors" + "github.com/supabase/auth/internal/utilities" +) + +type GotrueRequest struct { + Security GotrueSecurity `json:"gotrue_meta_security"` +} + +type GotrueSecurity struct { + Token string `json:"captcha_token"` +} + +type VerificationResponse struct { + Success bool `json:"success"` + ErrorCodes []string `json:"error-codes"` + Hostname string `json:"hostname"` +} + +var Client *http.Client + +func init() { + var defaultTimeout time.Duration = time.Second * 10 + timeoutStr := os.Getenv("GOTRUE_SECURITY_CAPTCHA_TIMEOUT") + if timeoutStr != "" { + if timeout, err := time.ParseDuration(timeoutStr); err != nil { + log.Fatalf("error loading GOTRUE_SECURITY_CAPTCHA_TIMEOUT: %v", err.Error()) + } else if timeout != 0 { + defaultTimeout = timeout + } + } + + Client = &http.Client{Timeout: defaultTimeout} +} + +func VerifyRequest(requestBody *GotrueRequest, clientIP, secretKey, captchaProvider string) (VerificationResponse, error) { + captchaResponse := strings.TrimSpace(requestBody.Security.Token) + + if captchaResponse == "" { + return VerificationResponse{}, errors.New("no captcha response (captcha_token) found in request") + } + + captchaURL, err := GetCaptchaURL(captchaProvider) + if err != nil { + return VerificationResponse{}, err + } + + return verifyCaptchaCode(captchaResponse, secretKey, clientIP, captchaURL) +} + +func verifyCaptchaCode(token, secretKey, clientIP, captchaURL string) (VerificationResponse, error) { + data := url.Values{} + data.Set("secret", secretKey) + data.Set("response", token) + data.Set("remoteip", clientIP) + // TODO (darora): pipe through sitekey + + r, err := http.NewRequest("POST", captchaURL, strings.NewReader(data.Encode())) + if err != nil { + return VerificationResponse{}, errors.Wrap(err, "couldn't initialize request object for captcha check") + } + r.Header.Add("Content-Type", "application/x-www-form-urlencoded") + r.Header.Add("Content-Length", strconv.Itoa(len(data.Encode()))) + res, err := Client.Do(r) + if err != nil { + return VerificationResponse{}, errors.Wrap(err, "failed to verify captcha response") + } + defer utilities.SafeClose(res.Body) + + var verificationResponse VerificationResponse + + if err := json.NewDecoder(res.Body).Decode(&verificationResponse); err != nil { + return VerificationResponse{}, errors.Wrap(err, "failed to decode captcha response: not JSON") + } + + return verificationResponse, nil +} + +func GetCaptchaURL(captchaProvider string) (string, error) { + switch captchaProvider { + case "hcaptcha": + return "https://hcaptcha.com/siteverify", nil + case "turnstile": + return "https://challenges.cloudflare.com/turnstile/v0/siteverify", nil + default: + return "", fmt.Errorf("captcha Provider %q could not be found", captchaProvider) + } +} diff --git a/auth_v2.169.0/internal/storage/dial.go b/auth_v2.169.0/internal/storage/dial.go new file mode 100644 index 0000000..3ee9939 --- /dev/null +++ b/auth_v2.169.0/internal/storage/dial.go @@ -0,0 +1,192 @@ +package storage + +import ( + "context" + "database/sql" + "net/url" + "reflect" + "time" + + "github.com/XSAM/otelsql" + "github.com/gobuffalo/pop/v6" + 
"github.com/gobuffalo/pop/v6/columns" + "github.com/jmoiron/sqlx" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/conf" +) + +// Connection is the interface a storage provider must implement. +type Connection struct { + *pop.Connection +} + +// Dial will connect to that storage engine +func Dial(config *conf.GlobalConfiguration) (*Connection, error) { + if config.DB.Driver == "" && config.DB.URL != "" { + u, err := url.Parse(config.DB.URL) + if err != nil { + return nil, errors.Wrap(err, "parsing db connection url") + } + config.DB.Driver = u.Scheme + } + + driver := "" + if config.DB.Driver != "postgres" { + logrus.Warn("DEPRECATION NOTICE: only PostgreSQL is supported by Supabase's GoTrue, will be removed soon") + } else { + // pop v5 uses pgx as the default PostgreSQL driver + driver = "pgx" + } + + if driver != "" && (config.Tracing.Enabled || config.Metrics.Enabled) { + instrumentedDriver, err := otelsql.Register(driver) + if err != nil { + logrus.WithError(err).Errorf("unable to instrument sql driver %q for use with OpenTelemetry", driver) + } else { + logrus.Debugf("using %s as an instrumented driver for OpenTelemetry", instrumentedDriver) + + // sqlx needs to be informed that the new instrumented + // driver has the same semantics as the + // non-instrumented driver + sqlx.BindDriver(instrumentedDriver, sqlx.BindType(driver)) + + driver = instrumentedDriver + } + } + + options := make(map[string]string) + + if config.DB.HealthCheckPeriod != time.Duration(0) { + options["pool_health_check_period"] = config.DB.HealthCheckPeriod.String() + } + + if config.DB.ConnMaxIdleTime != time.Duration(0) { + options["pool_max_conn_idle_time"] = config.DB.ConnMaxIdleTime.String() + } + + db, err := pop.NewConnection(&pop.ConnectionDetails{ + Dialect: config.DB.Driver, + Driver: driver, + URL: config.DB.URL, + Pool: config.DB.MaxPoolSize, + IdlePool: config.DB.MaxIdlePoolSize, + ConnMaxLifetime: config.DB.ConnMaxLifetime, + ConnMaxIdleTime: config.DB.ConnMaxIdleTime, + Options: options, + }) + if err != nil { + return nil, errors.Wrap(err, "opening database connection") + } + if err := db.Open(); err != nil { + return nil, errors.Wrap(err, "checking database connection") + } + + if config.Metrics.Enabled { + registerOpenTelemetryDatabaseStats(db) + } + + return &Connection{db}, nil +} + +func registerOpenTelemetryDatabaseStats(db *pop.Connection) { + defer func() { + if rec := recover(); rec != nil { + logrus.WithField("error", rec).Error("registerOpenTelemetryDatabaseStats is not able to determine database object with reflection -- panicked") + } + }() + + dbval := reflect.Indirect(reflect.ValueOf(db.Store)) + dbfield := dbval.Field(0) + sqldbfield := reflect.Indirect(dbfield).Field(0) + + sqldb, ok := sqldbfield.Interface().(*sql.DB) + if !ok || sqldb == nil { + logrus.Error("registerOpenTelemetryDatabaseStats is not able to determine database object with reflection") + return + } + + if err := otelsql.RegisterDBStatsMetrics(sqldb); err != nil { + logrus.WithError(err).Error("unable to register OpenTelemetry stats metrics for databse") + } else { + logrus.Debug("registered OpenTelemetry stats metrics for database") + } +} + +type CommitWithError struct { + Err error +} + +func (e *CommitWithError) Error() string { + return e.Err.Error() +} + +func (e *CommitWithError) Cause() error { + return e.Err +} + +// NewCommitWithError creates an error that can be returned in a pop transaction +// without rolling back the transaction. 
This should only be used in cases where +// you want the transaction to commit but return an error message to the user. +func NewCommitWithError(err error) *CommitWithError { + return &CommitWithError{Err: err} +} + +func (c *Connection) Transaction(fn func(*Connection) error) error { + if c.TX == nil { + var returnErr error + if terr := c.Connection.Transaction(func(tx *pop.Connection) error { + err := fn(&Connection{tx}) + switch err.(type) { + case *CommitWithError: + returnErr = err + return nil + default: + return err + } + }); terr != nil { + // there exists a race condition when the context deadline is exceeded + // and whether the transaction has been committed or not + // e.g. if the context deadline has exceeded but the transaction has already been committed, + // it won't be possible to perform a rollback on the transaction since the transaction has been closed + if !errors.Is(terr, sql.ErrTxDone) { + return terr + } + } + return returnErr + } + return fn(c) +} + +// WithContext returns a new connection with an updated context. This is +// typically used for tracing as the context contains trace span information. +func (c *Connection) WithContext(ctx context.Context) *Connection { + return &Connection{c.Connection.WithContext(ctx)} +} + +func getExcludedColumns(model interface{}, includeColumns ...string) ([]string, error) { + sm := &pop.Model{Value: model} + st := reflect.TypeOf(model) + if st.Kind() == reflect.Ptr { + _ = st.Elem() + } + + // get all columns and remove included to get excluded set + cols := columns.ForStructWithAlias(model, sm.TableName(), sm.As, sm.IDField()) + for _, f := range includeColumns { + if _, ok := cols.Cols[f]; !ok { + return nil, errors.Errorf("Invalid column name %s", f) + } + cols.Remove(f) + } + + xcols := make([]string, 0, len(cols.Cols)) + for n := range cols.Cols { + // gobuffalo updates the updated_at column automatically + if n == "updated_at" { + continue + } + xcols = append(xcols, n) + } + return xcols, nil +} diff --git a/auth_v2.169.0/internal/storage/dial_test.go b/auth_v2.169.0/internal/storage/dial_test.go new file mode 100644 index 0000000..078b6d5 --- /dev/null +++ b/auth_v2.169.0/internal/storage/dial_test.go @@ -0,0 +1,60 @@ +package storage + +import ( + "errors" + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +type TestUser struct { + ID uuid.UUID + Role string `db:"role"` + Other string `db:"othercol"` +} + +func TestGetExcludedColumns(t *testing.T) { + u := TestUser{} + cols, err := getExcludedColumns(u, "role") + require.NoError(t, err) + require.NotContains(t, cols, "role") + require.Contains(t, cols, "othercol") +} + +func TestGetExcludedColumns_InvalidName(t *testing.T) { + u := TestUser{} + _, err := getExcludedColumns(u, "adsf") + require.Error(t, err) +} + +func TestTransaction(t *testing.T) { + apiTestConfig := "../../hack/test.env" + config, err := conf.LoadGlobal(apiTestConfig) + require.NoError(t, err) + conn, err := Dial(config) + require.NoError(t, err) + require.NotNil(t, conn) + + defer func() { + // clean up the test table created + require.NoError(t, conn.RawQuery("drop table if exists test").Exec(), "Error removing table") + }() + + commitWithError := NewCommitWithError(errors.New("commit with error")) + err = conn.Transaction(func(tx *Connection) error { + require.NoError(t, tx.RawQuery("create table if not exists test()").Exec(), "Error saving creating test table") + return commitWithError + }) + require.Error(t, err) + 
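The Transaction wrapper above gives *CommitWithError special treatment: the callback's work is committed, yet the error is still returned to the caller, which is what the surrounding TestTransaction asserts. A minimal usage sketch, assuming it lives inside this module so it can import the internal storage package; the one_time_codes table and query are hypothetical.

package example

import (
	"errors"

	"github.com/supabase/auth/internal/storage"
)

// consumeCode commits the UPDATE even though it reports an error upward,
// so the caller can show a message without losing the state change.
func consumeCode(conn *storage.Connection, code string) error {
	return conn.Transaction(func(tx *storage.Connection) error {
		if err := tx.RawQuery("update one_time_codes set used = true where code = ?", code).Exec(); err != nil {
			return err // ordinary error: the transaction rolls back
		}
		// commit the update, but still surface a user-facing error
		return storage.NewCommitWithError(errors.New("code already used"))
	})
}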
require.ErrorIs(t, err, commitWithError) + + type TestData struct{} + + // check that transaction is still being committed despite returning an error above + data := []TestData{} + err = conn.RawQuery("select * from test").All(&data) + require.NoError(t, err) + require.Empty(t, data) +} diff --git a/auth_v2.169.0/internal/storage/helper.go b/auth_v2.169.0/internal/storage/helper.go new file mode 100644 index 0000000..2359984 --- /dev/null +++ b/auth_v2.169.0/internal/storage/helper.go @@ -0,0 +1,31 @@ +package storage + +import ( + "database/sql/driver" + "errors" +) + +type NullString string + +func (s *NullString) Scan(value interface{}) error { + if value == nil { + *s = "" + return nil + } + strVal, ok := value.(string) + if !ok { + return errors.New("column is not a string") + } + *s = NullString(strVal) + return nil +} +func (s NullString) Value() (driver.Value, error) { + if len(s) == 0 { // if nil or empty string + return nil, nil + } + return string(s), nil +} + +func (s NullString) String() string { + return string(s) +} diff --git a/auth_v2.169.0/internal/storage/sql.go b/auth_v2.169.0/internal/storage/sql.go new file mode 100644 index 0000000..2173411 --- /dev/null +++ b/auth_v2.169.0/internal/storage/sql.go @@ -0,0 +1,9 @@ +package storage + +func (conn *Connection) UpdateOnly(model interface{}, includeColumns ...string) error { + xcols, err := getExcludedColumns(model, includeColumns...) + if err != nil { + return err + } + return conn.Update(model, xcols...) +} diff --git a/auth_v2.169.0/internal/storage/test/db_setup.go b/auth_v2.169.0/internal/storage/test/db_setup.go new file mode 100644 index 0000000..8eeb099 --- /dev/null +++ b/auth_v2.169.0/internal/storage/test/db_setup.go @@ -0,0 +1,10 @@ +package test + +import ( + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" +) + +func SetupDBConnection(globalConfig *conf.GlobalConfiguration) (*storage.Connection, error) { + return storage.Dial(globalConfig) +} diff --git a/auth_v2.169.0/internal/utilities/context.go b/auth_v2.169.0/internal/utilities/context.go new file mode 100644 index 0000000..06aa74a --- /dev/null +++ b/auth_v2.169.0/internal/utilities/context.go @@ -0,0 +1,51 @@ +package utilities + +import ( + "context" + "sync" +) + +type contextKey string + +func (c contextKey) String() string { + return "gotrue api context key " + string(c) +} + +const ( + requestIDKey = contextKey("request_id") +) + +// WithRequestID adds the provided request ID to the context. +func WithRequestID(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, requestIDKey, id) +} + +// GetRequestID reads the request ID from the context. +func GetRequestID(ctx context.Context) string { + obj := ctx.Value(requestIDKey) + if obj == nil { + return "" + } + + return obj.(string) +} + +// WaitForCleanup waits until all long-running goroutines shut +// down cleanly or until the provided context signals done. 
+func WaitForCleanup(ctx context.Context, wg *sync.WaitGroup) { + cleanupDone := make(chan struct{}) + + go func() { + defer close(cleanupDone) + + wg.Wait() + }() + + select { + case <-ctx.Done(): + return + + case <-cleanupDone: + return + } +} diff --git a/auth_v2.169.0/internal/utilities/hibpcache.go b/auth_v2.169.0/internal/utilities/hibpcache.go new file mode 100644 index 0000000..14c3fc3 --- /dev/null +++ b/auth_v2.169.0/internal/utilities/hibpcache.go @@ -0,0 +1,76 @@ +package utilities + +import ( + "context" + "sync" + + "github.com/bits-and-blooms/bloom/v3" +) + +const ( + // hibpHashLength is the length of a hex-encoded SHA1 hash. + hibpHashLength = 40 + // hibpHashPrefixLength is the length of the hashed password prefix. + hibpHashPrefixLength = 5 +) + +type HIBPBloomCache struct { + sync.RWMutex + + n uint + items uint + filter *bloom.BloomFilter +} + +func NewHIBPBloomCache(n uint, fp float64) *HIBPBloomCache { + cache := &HIBPBloomCache{ + n: n, + filter: bloom.NewWithEstimates(n, fp), + } + + return cache +} + +func (c *HIBPBloomCache) Cap() uint { + return c.filter.Cap() +} + +func (c *HIBPBloomCache) Add(ctx context.Context, prefix []byte, suffixes [][]byte) error { + c.Lock() + defer c.Unlock() + + c.items += uint(len(suffixes)) + + if c.items > (4*c.n)/5 { + // clear the filter if 80% full to keep the actual false + // positive rate low + c.filter.ClearAll() + + // reduce memory footprint when this happens + c.filter.BitSet().Compact() + + c.items = uint(len(suffixes)) + } + + var combined [hibpHashLength]byte + copy(combined[:], prefix) + + for _, suffix := range suffixes { + copy(combined[hibpHashPrefixLength:], suffix) + + c.filter.Add(combined[:]) + } + + return nil +} + +func (c *HIBPBloomCache) Contains(ctx context.Context, prefix, suffix []byte) (bool, error) { + var combined [hibpHashLength]byte + copy(combined[:], prefix) + copy(combined[hibpHashPrefixLength:], suffix) + + c.RLock() + defer c.RUnlock() + + return c.filter.Test(combined[:]), nil +} diff --git a/auth_v2.169.0/internal/utilities/io.go b/auth_v2.169.0/internal/utilities/io.go new file mode 100644 index 0000000..ab89b4c --- /dev/null +++ b/auth_v2.169.0/internal/utilities/io.go @@ -0,0 +1,13 @@ +package utilities + +import ( + "io" + + "github.com/sirupsen/logrus" +) + +func SafeClose(closer io.Closer) { + if err := closer.Close(); err != nil { + logrus.WithError(err).Warn("Close operation failed") + } +} diff --git a/auth_v2.169.0/internal/utilities/postgres.go b/auth_v2.169.0/internal/utilities/postgres.go new file mode 100644 index 0000000..4d7fde8 --- /dev/null +++ b/auth_v2.169.0/internal/utilities/postgres.go @@ -0,0 +1,76 @@ +package utilities + +import ( + "errors" + "strconv" + "strings" + + "github.com/jackc/pgconn" + "github.com/jackc/pgerrcode" +) + +// PostgresError is a custom error struct for marshalling Postgres errors to JSON. +type PostgresError struct { + Code string `json:"code"` + HttpStatusCode int `json:"-"` + Message string `json:"message"` + Hint string `json:"hint,omitempty"` + Detail string `json:"detail,omitempty"` +} + +// NewPostgresError returns a new PostgresError if the error was from a publicly +// accessible Postgres error. 
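The bloom-filter cache above stores leaked-password hashes using the HIBP k-anonymity split: the first 5 hex characters of the SHA-1 act as the range prefix, the remaining 35 as the suffix, and the filter clears itself once it is roughly 80% full. A short usage sketch, assuming it runs inside this module so it can import the internal utilities package; the sample password, sizes, and lowercase hex encoding are illustrative.

package example

import (
	"context"
	"crypto/sha1" // #nosec G401 -- HIBP ranges are keyed by SHA-1
	"encoding/hex"
	"fmt"

	"github.com/supabase/auth/internal/utilities"
)

func demoHIBPCache(ctx context.Context) error {
	cache := utilities.NewHIBPBloomCache(100_000, 0.01) // ~100k entries, 1% false-positive target

	sum := sha1.Sum([]byte("correct horse battery staple"))
	full := hex.EncodeToString(sum[:]) // 40 hex characters
	prefix, suffix := []byte(full[:5]), []byte(full[5:])

	// Remember this suffix under its 5-character range prefix.
	if err := cache.Add(ctx, prefix, [][]byte{suffix}); err != nil {
		return err
	}

	seen, err := cache.Contains(ctx, prefix, suffix)
	if err != nil {
		return err
	}
	fmt.Println("cached:", seen) // true, subject to the filter's false-positive rate
	return nil
}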
+func NewPostgresError(err error) *PostgresError { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && isPubliclyAccessiblePostgresError(pgErr.Code) { + return &PostgresError{ + Code: pgErr.Code, + HttpStatusCode: getHttpStatusCodeFromPostgresErrorCode(pgErr.Code), + Message: pgErr.Message, + Detail: pgErr.Detail, + Hint: pgErr.Hint, + } + } + + return nil +} +func (pg *PostgresError) IsUniqueConstraintViolated() bool { + // See https://www.postgresql.org/docs/current/errcodes-appendix.html for list of error codes + return pg.Code == "23505" +} + +// isPubliclyAccessiblePostgresError checks if the Postgres error should be +// made accessible. +func isPubliclyAccessiblePostgresError(code string) bool { + if len(code) != 5 { + return false + } + + // default response + return getHttpStatusCodeFromPostgresErrorCode(code) != 0 +} + +// getHttpStatusCodeFromPostgresErrorCode maps a Postgres error code to a HTTP +// status code. Returns 0 if the code doesn't map to a given postgres error code. +func getHttpStatusCodeFromPostgresErrorCode(code string) int { + if code == pgerrcode.RaiseException || + code == pgerrcode.IntegrityConstraintViolation || + code == pgerrcode.RestrictViolation || + code == pgerrcode.NotNullViolation || + code == pgerrcode.ForeignKeyViolation || + code == pgerrcode.UniqueViolation || + code == pgerrcode.CheckViolation || + code == pgerrcode.ExclusionViolation { + return 500 + } + + // Use custom HTTP status code if Postgres error was triggered with `PTXXX` + // code. This is consistent with PostgREST's behaviour as well. + if strings.HasPrefix(code, "PT") { + if httpStatusCode, err := strconv.ParseInt(code[2:], 10, 0); err == nil { + return int(httpStatusCode) + } + } + + return 0 +} diff --git a/auth_v2.169.0/internal/utilities/request.go b/auth_v2.169.0/internal/utilities/request.go new file mode 100644 index 0000000..b6b8697 --- /dev/null +++ b/auth_v2.169.0/internal/utilities/request.go @@ -0,0 +1,117 @@ +package utilities + +import ( + "bytes" + "io" + "net" + "net/http" + "net/url" + "strings" + + "github.com/supabase/auth/internal/conf" +) + +// GetIPAddress returns the real IP address of the HTTP request. It parses the +// X-Forwarded-For header. +func GetIPAddress(r *http.Request) string { + if r.Header != nil { + xForwardedFor := r.Header.Get("X-Forwarded-For") + if xForwardedFor != "" { + ips := strings.Split(xForwardedFor, ",") + for i := range ips { + ips[i] = strings.TrimSpace(ips[i]) + } + + for _, ip := range ips { + if ip != "" { + parsed := net.ParseIP(ip) + if parsed == nil { + continue + } + + return parsed.String() + } + } + } + } + + ipPort := r.RemoteAddr + ip, _, err := net.SplitHostPort(ipPort) + if err != nil { + return ipPort + } + + return ip +} + +// GetBodyBytes reads the whole request body properly into a byte array. 
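getHttpStatusCodeFromPostgresErrorCode above also honours the PostgREST-style convention that an error raised with a PTxxx SQLSTATE carries its own HTTP status. A small sketch of what that looks like from Go, assuming it runs inside this module so it can import the internal utilities package; the error text is made up.

package example

import (
	"fmt"

	"github.com/jackc/pgconn"
	"github.com/supabase/auth/internal/utilities"
)

func demoPostgresError() {
	// e.g. a trigger doing: RAISE EXCEPTION 'payment required' USING ERRCODE = 'PT402';
	pgErr := &pgconn.PgError{Code: "PT402", Message: "payment required"}

	if perr := utilities.NewPostgresError(pgErr); perr != nil {
		fmt.Println(perr.HttpStatusCode, perr.Message) // 402 payment required
	}
}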
+func GetBodyBytes(req *http.Request) ([]byte, error) { + if req.Body == nil || req.Body == http.NoBody { + return nil, nil + } + + originalBody := req.Body + defer SafeClose(originalBody) + + buf, err := io.ReadAll(originalBody) + if err != nil { + return nil, err + } + + req.Body = io.NopCloser(bytes.NewReader(buf)) + + return buf, nil +} + +func GetReferrer(r *http.Request, config *conf.GlobalConfiguration) string { + // try get redirect url from query or post data first + reqref := getRedirectTo(r) + if IsRedirectURLValid(config, reqref) { + return reqref + } + + // instead try referrer header value + reqref = r.Referer() + if IsRedirectURLValid(config, reqref) { + return reqref + } + + return config.SiteURL +} + +func IsRedirectURLValid(config *conf.GlobalConfiguration, redirectURL string) bool { + if redirectURL == "" { + return false + } + + base, berr := url.Parse(config.SiteURL) + refurl, rerr := url.Parse(redirectURL) + + // As long as the referrer came from the site, we will redirect back there + if berr == nil && rerr == nil && base.Hostname() == refurl.Hostname() { + return true + } + + // For case when user came from mobile app or other permitted resource - redirect back + for _, pattern := range config.URIAllowListMap { + if pattern.Match(redirectURL) { + return true + } + } + + return false +} + +// getRedirectTo tries extract redirect url from header or from query params +func getRedirectTo(r *http.Request) (reqref string) { + reqref = r.Header.Get("redirect_to") + if reqref != "" { + return + } + + if err := r.ParseForm(); err == nil { + reqref = r.Form.Get("redirect_to") + } + + return +} diff --git a/auth_v2.169.0/internal/utilities/request_test.go b/auth_v2.169.0/internal/utilities/request_test.go new file mode 100644 index 0000000..6704e39 --- /dev/null +++ b/auth_v2.169.0/internal/utilities/request_test.go @@ -0,0 +1,134 @@ +package utilities + +import ( + "net/http" + "net/http/httptest" + tst "testing" + + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" +) + +func TestGetIPAddress(t *tst.T) { + examples := []func(r *http.Request) string{ + func(r *http.Request) string { + r.Header = nil + r.RemoteAddr = "127.0.0.1:8080" + + return "127.0.0.1" + }, + + func(r *http.Request) string { + r.Header = nil + r.RemoteAddr = "incorrect" + + return "incorrect" + }, + + func(r *http.Request) string { + r.Header = make(http.Header) + r.RemoteAddr = "127.0.0.1:8080" + + return "127.0.0.1" + }, + + func(r *http.Request) string { + r.Header = make(http.Header) + r.RemoteAddr = "[::1]:8080" + + return "::1" + }, + + func(r *http.Request) string { + r.Header = make(http.Header) + r.RemoteAddr = "127.0.0.1:8080" + r.Header.Add("X-Forwarded-For", "127.0.0.2") + + return "127.0.0.2" + }, + + func(r *http.Request) string { + r.Header = make(http.Header) + r.RemoteAddr = "127.0.0.1:8080" + r.Header.Add("X-Forwarded-For", "127.0.0.2") + + return "127.0.0.2" + }, + + func(r *http.Request) string { + r.Header = make(http.Header) + r.RemoteAddr = "127.0.0.1:8080" + r.Header.Add("X-Forwarded-For", "127.0.0.2,") + + return "127.0.0.2" + }, + + func(r *http.Request) string { + r.Header = make(http.Header) + r.RemoteAddr = "127.0.0.1:8080" + r.Header.Add("X-Forwarded-For", "127.0.0.2,127.0.0.3") + + return "127.0.0.2" + }, + + func(r *http.Request) string { + r.Header = make(http.Header) + r.RemoteAddr = "127.0.0.1:8080" + r.Header.Add("X-Forwarded-For", "::1,127.0.0.2") + + return "::1" + }, + } + + for _, example := range examples { + req := &http.Request{} + 
expected := example(req) + + require.Equal(t, GetIPAddress(req), expected) + } +} + +func TestGetReferrer(t *tst.T) { + config := conf.GlobalConfiguration{ + SiteURL: "https://example.com", + URIAllowList: []string{"http://localhost:8000/*"}, + JWT: conf.JWTConfiguration{ + Secret: "testsecret", + }, + } + require.NoError(t, config.ApplyDefaults()) + cases := []struct { + desc string + redirectURL string + expected string + }{ + { + desc: "valid redirect url", + redirectURL: "http://localhost:8000/path", + expected: "http://localhost:8000/path", + }, + { + desc: "invalid redirect url", + redirectURL: "http://localhost:3000", + expected: config.SiteURL, + }, + { + desc: "no / separator", + redirectURL: "http://localhost:8000", + expected: config.SiteURL, + }, + { + desc: "* respects separator", + redirectURL: "http://localhost:8000/path/to/page", + expected: config.SiteURL, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *tst.T) { + r := httptest.NewRequest("GET", "http://localhost?redirect_to="+c.redirectURL, nil) + referrer := GetReferrer(r, &config) + require.Equal(t, c.expected, referrer) + }) + } +} diff --git a/auth_v2.169.0/internal/utilities/version.go b/auth_v2.169.0/internal/utilities/version.go new file mode 100644 index 0000000..b3ba95a --- /dev/null +++ b/auth_v2.169.0/internal/utilities/version.go @@ -0,0 +1,4 @@ +package utilities + +// Version is git commit or release tag from which this binary was built. +var Version string diff --git a/auth_v2.169.0/main.go b/auth_v2.169.0/main.go new file mode 100644 index 0000000..7455193 --- /dev/null +++ b/auth_v2.169.0/main.go @@ -0,0 +1,68 @@ +package main + +import ( + "context" + "embed" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/sirupsen/logrus" + "github.com/supabase/auth/cmd" + "github.com/supabase/auth/internal/observability" +) + +//go:embed migrations/* +var embeddedMigrations embed.FS + +func init() { + logrus.SetFormatter(&logrus.JSONFormatter{}) +} + +func main() { + cmd.EmbeddedMigrations = embeddedMigrations + + execCtx, execCancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGHUP, syscall.SIGINT) + defer execCancel() + + go func() { + <-execCtx.Done() + logrus.Info("received graceful shutdown signal") + }() + + // command is expected to obey the cancellation signal on execCtx and + // block while it is running + if err := cmd.RootCommand().ExecuteContext(execCtx); err != nil { + logrus.WithError(err).Fatal(err) + } + + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Minute) + defer shutdownCancel() + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + + // wait for profiler, metrics and trace exporters to shut down gracefully + observability.WaitForCleanup(shutdownCtx) + }() + + cleanupDone := make(chan struct{}) + go func() { + defer close(cleanupDone) + wg.Wait() + }() + + select { + case <-shutdownCtx.Done(): + // cleanup timed out + return + + case <-cleanupDone: + // cleanup finished before timing out + return + } +} diff --git a/auth_v2.169.0/migrations/00_init_auth_schema.up.sql b/auth_v2.169.0/migrations/00_init_auth_schema.up.sql new file mode 100644 index 0000000..a040095 --- /dev/null +++ b/auth_v2.169.0/migrations/00_init_auth_schema.up.sql @@ -0,0 +1,88 @@ +-- auth.users definition + +CREATE TABLE IF NOT EXISTS {{ index .Options "Namespace" }}.users ( + instance_id uuid NULL, + id uuid NOT NULL UNIQUE, + aud varchar(255) NULL, + "role" varchar(255) NULL, + email varchar(255) NULL UNIQUE, + 
encrypted_password varchar(255) NULL, + confirmed_at timestamptz NULL, + invited_at timestamptz NULL, + confirmation_token varchar(255) NULL, + confirmation_sent_at timestamptz NULL, + recovery_token varchar(255) NULL, + recovery_sent_at timestamptz NULL, + email_change_token varchar(255) NULL, + email_change varchar(255) NULL, + email_change_sent_at timestamptz NULL, + last_sign_in_at timestamptz NULL, + raw_app_meta_data jsonb NULL, + raw_user_meta_data jsonb NULL, + is_super_admin bool NULL, + created_at timestamptz NULL, + updated_at timestamptz NULL, + CONSTRAINT users_pkey PRIMARY KEY (id) +); +CREATE INDEX IF NOT EXISTS users_instance_id_email_idx ON {{ index .Options "Namespace" }}.users USING btree (instance_id, email); +CREATE INDEX IF NOT EXISTS users_instance_id_idx ON {{ index .Options "Namespace" }}.users USING btree (instance_id); +comment on table {{ index .Options "Namespace" }}.users is 'Auth: Stores user login data within a secure schema.'; + +-- auth.refresh_tokens definition + +CREATE TABLE IF NOT EXISTS {{ index .Options "Namespace" }}.refresh_tokens ( + instance_id uuid NULL, + id bigserial NOT NULL, + "token" varchar(255) NULL, + user_id varchar(255) NULL, + revoked bool NULL, + created_at timestamptz NULL, + updated_at timestamptz NULL, + CONSTRAINT refresh_tokens_pkey PRIMARY KEY (id) +); +CREATE INDEX IF NOT EXISTS refresh_tokens_instance_id_idx ON {{ index .Options "Namespace" }}.refresh_tokens USING btree (instance_id); +CREATE INDEX IF NOT EXISTS refresh_tokens_instance_id_user_id_idx ON {{ index .Options "Namespace" }}.refresh_tokens USING btree (instance_id, user_id); +CREATE INDEX IF NOT EXISTS refresh_tokens_token_idx ON {{ index .Options "Namespace" }}.refresh_tokens USING btree (token); +comment on table {{ index .Options "Namespace" }}.refresh_tokens is 'Auth: Store of tokens used to refresh JWT tokens once they expire.'; + +-- auth.instances definition + +CREATE TABLE IF NOT EXISTS {{ index .Options "Namespace" }}.instances ( + id uuid NOT NULL, + uuid uuid NULL, + raw_base_config text NULL, + created_at timestamptz NULL, + updated_at timestamptz NULL, + CONSTRAINT instances_pkey PRIMARY KEY (id) +); +comment on table {{ index .Options "Namespace" }}.instances is 'Auth: Manages users across multiple sites.'; + +-- auth.audit_log_entries definition + +CREATE TABLE IF NOT EXISTS {{ index .Options "Namespace" }}.audit_log_entries ( + instance_id uuid NULL, + id uuid NOT NULL, + payload json NULL, + created_at timestamptz NULL, + CONSTRAINT audit_log_entries_pkey PRIMARY KEY (id) +); +CREATE INDEX IF NOT EXISTS audit_logs_instance_id_idx ON {{ index .Options "Namespace" }}.audit_log_entries USING btree (instance_id); +comment on table {{ index .Options "Namespace" }}.audit_log_entries is 'Auth: Audit trail for user actions.'; + +-- auth.schema_migrations definition + +CREATE TABLE IF NOT EXISTS {{ index .Options "Namespace" }}.schema_migrations ( + "version" varchar(255) NOT NULL, + CONSTRAINT schema_migrations_pkey PRIMARY KEY ("version") +); +comment on table {{ index .Options "Namespace" }}.schema_migrations is 'Auth: Manages updates to the auth system.'; + +-- Gets the User ID from the request cookie +create or replace function {{ index .Options "Namespace" }}.uid() returns uuid as $$ + select nullif(current_setting('request.jwt.claim.sub', true), '')::uuid; +$$ language sql stable; + +-- Gets the User ID from the request cookie +create or replace function {{ index .Options "Namespace" }}.role() returns text as $$ + select 
nullif(current_setting('request.jwt.claim.role', true), '')::text; +$$ language sql stable; diff --git a/auth_v2.169.0/migrations/20210710035447_alter_users.up.sql b/auth_v2.169.0/migrations/20210710035447_alter_users.up.sql new file mode 100644 index 0000000..fc8de12 --- /dev/null +++ b/auth_v2.169.0/migrations/20210710035447_alter_users.up.sql @@ -0,0 +1,19 @@ +-- alter user schema + +ALTER TABLE {{ index .Options "Namespace" }}.users +ADD COLUMN IF NOT EXISTS phone VARCHAR(15) NULL UNIQUE DEFAULT NULL, +ADD COLUMN IF NOT EXISTS phone_confirmed_at timestamptz NULL DEFAULT NULL, +ADD COLUMN IF NOT EXISTS phone_change VARCHAR(15) NULL DEFAULT '', +ADD COLUMN IF NOT EXISTS phone_change_token VARCHAR(255) NULL DEFAULT '', +ADD COLUMN IF NOT EXISTS phone_change_sent_at timestamptz NULL DEFAULT NULL; + +DO $$ +BEGIN + IF NOT EXISTS(SELECT * + FROM information_schema.columns + WHERE table_schema = '{{ index .Options "Namespace" }}' and table_name='users' and column_name='email_confirmed_at') + THEN + ALTER TABLE "{{ index .Options "Namespace" }}"."users" RENAME COLUMN "confirmed_at" TO "email_confirmed_at"; + END IF; +END $$; + diff --git a/auth_v2.169.0/migrations/20210722035447_adds_confirmed_at.up.sql b/auth_v2.169.0/migrations/20210722035447_adds_confirmed_at.up.sql new file mode 100644 index 0000000..aabd42e --- /dev/null +++ b/auth_v2.169.0/migrations/20210722035447_adds_confirmed_at.up.sql @@ -0,0 +1,4 @@ +-- adds confirmed at + +ALTER TABLE {{ index .Options "Namespace" }}.users +ADD COLUMN IF NOT EXISTS confirmed_at timestamptz GENERATED ALWAYS AS (LEAST (users.email_confirmed_at, users.phone_confirmed_at)) STORED; diff --git a/auth_v2.169.0/migrations/20210730183235_add_email_change_confirmed.up.sql b/auth_v2.169.0/migrations/20210730183235_add_email_change_confirmed.up.sql new file mode 100644 index 0000000..dc92c9c --- /dev/null +++ b/auth_v2.169.0/migrations/20210730183235_add_email_change_confirmed.up.sql @@ -0,0 +1,15 @@ +-- adds email_change_confirmed + +ALTER TABLE {{ index .Options "Namespace" }}.users +ADD COLUMN IF NOT EXISTS email_change_token_current varchar(255) null DEFAULT '', +ADD COLUMN IF NOT EXISTS email_change_confirm_status smallint DEFAULT 0 CHECK (email_change_confirm_status >= 0 AND email_change_confirm_status <= 2); + +DO $$ +BEGIN + IF NOT EXISTS(SELECT * + FROM information_schema.columns + WHERE table_schema = '{{ index .Options "Namespace" }}' and table_name='users' and column_name='email_change_token_new') + THEN + ALTER TABLE "{{ index .Options "Namespace" }}"."users" RENAME COLUMN "email_change_token" TO "email_change_token_new"; + END IF; +END $$; diff --git a/auth_v2.169.0/migrations/20210909172000_create_identities_table.up.sql b/auth_v2.169.0/migrations/20210909172000_create_identities_table.up.sql new file mode 100644 index 0000000..2f3a535 --- /dev/null +++ b/auth_v2.169.0/migrations/20210909172000_create_identities_table.up.sql @@ -0,0 +1,14 @@ +-- adds identities table + +CREATE TABLE IF NOT EXISTS {{ index .Options "Namespace" }}.identities ( + id text NOT NULL, + user_id uuid NOT NULL, + identity_data JSONB NOT NULL, + provider text NOT NULL, + last_sign_in_at timestamptz NULL, + created_at timestamptz NULL, + updated_at timestamptz NULL, + CONSTRAINT identities_pkey PRIMARY KEY (provider, id), + CONSTRAINT identities_user_id_fkey FOREIGN KEY (user_id) REFERENCES {{ index .Options "Namespace" }}.users(id) ON DELETE CASCADE +); +COMMENT ON TABLE {{ index .Options "Namespace" }}.identities is 'Auth: Stores identities associated to a user.'; diff 
--git a/auth_v2.169.0/migrations/20210927181326_add_refresh_token_parent.up.sql b/auth_v2.169.0/migrations/20210927181326_add_refresh_token_parent.up.sql new file mode 100644 index 0000000..a2b1c73 --- /dev/null +++ b/auth_v2.169.0/migrations/20210927181326_add_refresh_token_parent.up.sql @@ -0,0 +1,24 @@ +-- adds parent column + +ALTER TABLE {{ index .Options "Namespace" }}.refresh_tokens +ADD COLUMN IF NOT EXISTS parent varchar(255) NULL; + +DO $$ +BEGIN + IF NOT EXISTS(SELECT * + FROM information_schema.constraint_column_usage + WHERE table_schema = '{{ index .Options "Namespace" }}' and table_name='refresh_tokens' and constraint_name='refresh_tokens_token_unique') + THEN + ALTER TABLE "{{ index .Options "Namespace" }}"."refresh_tokens" ADD CONSTRAINT refresh_tokens_token_unique UNIQUE ("token"); + END IF; + + IF NOT EXISTS(SELECT * + FROM information_schema.constraint_column_usage + WHERE table_schema = '{{ index .Options "Namespace" }}' and table_name='refresh_tokens' and constraint_name='refresh_tokens_parent_fkey') + THEN + ALTER TABLE "{{ index .Options "Namespace" }}"."refresh_tokens" ADD CONSTRAINT refresh_tokens_parent_fkey FOREIGN KEY (parent) REFERENCES {{ index .Options "Namespace" }}.refresh_tokens("token"); + END IF; + + CREATE INDEX IF NOT EXISTS refresh_tokens_parent_idx ON "{{ index .Options "Namespace" }}"."refresh_tokens" USING btree (parent); +END $$; + diff --git a/auth_v2.169.0/migrations/20211122151130_create_user_id_idx.up.sql b/auth_v2.169.0/migrations/20211122151130_create_user_id_idx.up.sql new file mode 100644 index 0000000..d259aae --- /dev/null +++ b/auth_v2.169.0/migrations/20211122151130_create_user_id_idx.up.sql @@ -0,0 +1,3 @@ +-- create index on identities.user_id + +CREATE INDEX IF NOT EXISTS identities_user_id_idx ON "{{ index .Options "Namespace" }}".identities using btree (user_id); diff --git a/auth_v2.169.0/migrations/20211124214934_update_auth_functions.up.sql b/auth_v2.169.0/migrations/20211124214934_update_auth_functions.up.sql new file mode 100644 index 0000000..2fb784b --- /dev/null +++ b/auth_v2.169.0/migrations/20211124214934_update_auth_functions.up.sql @@ -0,0 +1,34 @@ +-- update auth functions + +create or replace function {{ index .Options "Namespace" }}.uid() +returns uuid +language sql stable +as $$ + select + coalesce( + current_setting('request.jwt.claim.sub', true), + (current_setting('request.jwt.claims', true)::jsonb ->> 'sub') + )::uuid +$$; + +create or replace function {{ index .Options "Namespace" }}.role() +returns text +language sql stable +as $$ + select + coalesce( + current_setting('request.jwt.claim.role', true), + (current_setting('request.jwt.claims', true)::jsonb ->> 'role') + )::text +$$; + +create or replace function {{ index .Options "Namespace" }}.email() +returns text +language sql stable +as $$ + select + coalesce( + current_setting('request.jwt.claim.email', true), + (current_setting('request.jwt.claims', true)::jsonb ->> 'email') + )::text +$$; diff --git a/auth_v2.169.0/migrations/20211202183645_update_auth_uid.up.sql b/auth_v2.169.0/migrations/20211202183645_update_auth_uid.up.sql new file mode 100644 index 0000000..3ecadfd --- /dev/null +++ b/auth_v2.169.0/migrations/20211202183645_update_auth_uid.up.sql @@ -0,0 +1,15 @@ +-- update auth.uid() + +create or replace function {{ index .Options "Namespace" }}.uid() +returns uuid +language sql stable +as $$ + select + nullif( + coalesce( + current_setting('request.jwt.claim.sub', true), + (current_setting('request.jwt.claims', true)::jsonb ->> 'sub') + ), + '' + 
)::uuid +$$; diff --git a/auth_v2.169.0/migrations/20220114185221_update_user_idx.up.sql b/auth_v2.169.0/migrations/20220114185221_update_user_idx.up.sql new file mode 100644 index 0000000..02fe76a --- /dev/null +++ b/auth_v2.169.0/migrations/20220114185221_update_user_idx.up.sql @@ -0,0 +1,4 @@ +-- updates users_instance_id_email_idx definition + +DROP INDEX IF EXISTS users_instance_id_email_idx; +CREATE INDEX IF NOT EXISTS users_instance_id_email_idx on "{{ index .Options "Namespace" }}".users using btree (instance_id, lower(email)); diff --git a/auth_v2.169.0/migrations/20220114185340_add_banned_until.up.sql b/auth_v2.169.0/migrations/20220114185340_add_banned_until.up.sql new file mode 100644 index 0000000..7530a7c --- /dev/null +++ b/auth_v2.169.0/migrations/20220114185340_add_banned_until.up.sql @@ -0,0 +1,4 @@ +-- adds banned_until column + +ALTER TABLE {{ index .Options "Namespace" }}.users +ADD COLUMN IF NOT EXISTS banned_until timestamptz NULL; diff --git a/auth_v2.169.0/migrations/20220224000811_update_auth_functions.up.sql b/auth_v2.169.0/migrations/20220224000811_update_auth_functions.up.sql new file mode 100644 index 0000000..4be4237 --- /dev/null +++ b/auth_v2.169.0/migrations/20220224000811_update_auth_functions.up.sql @@ -0,0 +1,34 @@ +-- update auth functions + +create or replace function {{ index .Options "Namespace" }}.uid() +returns uuid +language sql stable +as $$ + select + coalesce( + nullif(current_setting('request.jwt.claim.sub', true), ''), + (nullif(current_setting('request.jwt.claims', true), '')::jsonb ->> 'sub') + )::uuid +$$; + +create or replace function {{ index .Options "Namespace" }}.role() +returns text +language sql stable +as $$ + select + coalesce( + nullif(current_setting('request.jwt.claim.role', true), ''), + (nullif(current_setting('request.jwt.claims', true), '')::jsonb ->> 'role') + )::text +$$; + +create or replace function {{ index .Options "Namespace" }}.email() +returns text +language sql stable +as $$ + select + coalesce( + nullif(current_setting('request.jwt.claim.email', true), ''), + (nullif(current_setting('request.jwt.claims', true), '')::jsonb ->> 'email') + )::text +$$; diff --git a/auth_v2.169.0/migrations/20220323170000_add_user_reauthentication.up.sql b/auth_v2.169.0/migrations/20220323170000_add_user_reauthentication.up.sql new file mode 100644 index 0000000..277dbdb --- /dev/null +++ b/auth_v2.169.0/migrations/20220323170000_add_user_reauthentication.up.sql @@ -0,0 +1,5 @@ +-- adds reauthentication_token and reauthentication_sent_at + +ALTER TABLE {{ index .Options "Namespace" }}.users +ADD COLUMN IF NOT EXISTS reauthentication_token varchar(255) null default '', +ADD COLUMN IF NOT EXISTS reauthentication_sent_at timestamptz null default null; diff --git a/auth_v2.169.0/migrations/20220429102000_add_unique_idx.up.sql b/auth_v2.169.0/migrations/20220429102000_add_unique_idx.up.sql new file mode 100644 index 0000000..9d7644d --- /dev/null +++ b/auth_v2.169.0/migrations/20220429102000_add_unique_idx.up.sql @@ -0,0 +1,14 @@ +-- add partial unique indices to confirmation_token, recovery_token, email_change_token_current, email_change_token_new, phone_change_token, reauthentication_token +-- ignores partial unique index creation on fields which contain empty strings, whitespaces or purely numeric otps + +DROP INDEX IF EXISTS confirmation_token_idx; +DROP INDEX IF EXISTS recovery_token_idx; +DROP INDEX IF EXISTS email_change_token_current_idx; +DROP INDEX IF EXISTS email_change_token_new_idx; +DROP INDEX IF EXISTS 
reauthentication_token_idx; + +CREATE UNIQUE INDEX IF NOT EXISTS confirmation_token_idx ON {{ index .Options "Namespace" }}.users USING btree (confirmation_token) WHERE confirmation_token !~ '^[0-9 ]*$'; +CREATE UNIQUE INDEX IF NOT EXISTS recovery_token_idx ON {{ index .Options "Namespace" }}.users USING btree (recovery_token) WHERE recovery_token !~ '^[0-9 ]*$'; +CREATE UNIQUE INDEX IF NOT EXISTS email_change_token_current_idx ON {{ index .Options "Namespace" }}.users USING btree (email_change_token_current) WHERE email_change_token_current !~ '^[0-9 ]*$'; +CREATE UNIQUE INDEX IF NOT EXISTS email_change_token_new_idx ON {{ index .Options "Namespace" }}.users USING btree (email_change_token_new) WHERE email_change_token_new !~ '^[0-9 ]*$'; +CREATE UNIQUE INDEX IF NOT EXISTS reauthentication_token_idx ON {{ index .Options "Namespace" }}.users USING btree (reauthentication_token) WHERE reauthentication_token !~ '^[0-9 ]*$'; diff --git a/auth_v2.169.0/migrations/20220531120530_add_auth_jwt_function.up.sql b/auth_v2.169.0/migrations/20220531120530_add_auth_jwt_function.up.sql new file mode 100644 index 0000000..11f84e8 --- /dev/null +++ b/auth_v2.169.0/migrations/20220531120530_add_auth_jwt_function.up.sql @@ -0,0 +1,16 @@ +-- add auth.jwt function + +comment on function {{ index .Options "Namespace" }}.uid() is 'Deprecated. Use auth.jwt() -> ''sub'' instead.'; +comment on function {{ index .Options "Namespace" }}.role() is 'Deprecated. Use auth.jwt() -> ''role'' instead.'; +comment on function {{ index .Options "Namespace" }}.email() is 'Deprecated. Use auth.jwt() -> ''email'' instead.'; + +create or replace function {{ index .Options "Namespace" }}.jwt() +returns jsonb +language sql stable +as $$ + select + coalesce( + nullif(current_setting('request.jwt.claim', true), ''), + nullif(current_setting('request.jwt.claims', true), '') + )::jsonb +$$; diff --git a/auth_v2.169.0/migrations/20220614074223_add_ip_address_to_audit_log.postgres.up.sql b/auth_v2.169.0/migrations/20220614074223_add_ip_address_to_audit_log.postgres.up.sql new file mode 100644 index 0000000..a1a66b4 --- /dev/null +++ b/auth_v2.169.0/migrations/20220614074223_add_ip_address_to_audit_log.postgres.up.sql @@ -0,0 +1,3 @@ +-- Add IP Address to audit log +ALTER TABLE {{ index .Options "Namespace" }}.audit_log_entries +ADD COLUMN IF NOT EXISTS ip_address VARCHAR(64) NOT NULL DEFAULT ''; diff --git a/auth_v2.169.0/migrations/20220811173540_add_sessions_table.up.sql b/auth_v2.169.0/migrations/20220811173540_add_sessions_table.up.sql new file mode 100644 index 0000000..c16ef3c --- /dev/null +++ b/auth_v2.169.0/migrations/20220811173540_add_sessions_table.up.sql @@ -0,0 +1,23 @@ +-- Add session_id column to refresh_tokens table +create table if not exists {{ index .Options "Namespace" }}.sessions ( + id uuid not null, + user_id uuid not null, + created_at timestamptz null, + updated_at timestamptz null, + constraint sessions_pkey primary key (id), + constraint sessions_user_id_fkey foreign key (user_id) references {{ index .Options "Namespace" }}.users(id) on delete cascade +); +comment on table {{ index .Options "Namespace" }}.sessions is 'Auth: Stores session data associated to a user.'; + +alter table {{ index .Options "Namespace" }}.refresh_tokens +add column if not exists session_id uuid null; + +do $$ +begin + if not exists(select * + from information_schema.constraint_column_usage + where table_schema = '{{ index .Options "Namespace" }}' and table_name='sessions' and constraint_name='refresh_tokens_session_id_fkey') + then + 
alter table "{{ index .Options "Namespace" }}"."refresh_tokens" add constraint refresh_tokens_session_id_fkey foreign key (session_id) references {{ index .Options "Namespace" }}.sessions(id) on delete cascade; + end if; +END $$; diff --git a/auth_v2.169.0/migrations/20221003041349_add_mfa_schema.up.sql b/auth_v2.169.0/migrations/20221003041349_add_mfa_schema.up.sql new file mode 100644 index 0000000..a44654a --- /dev/null +++ b/auth_v2.169.0/migrations/20221003041349_add_mfa_schema.up.sql @@ -0,0 +1,50 @@ +-- see: https://stackoverflow.com/questions/7624919/check-if-a-user-defined-type-already-exists-in-postgresql/48382296#48382296 +do $$ begin + create type factor_type as enum('totp', 'webauthn'); + create type factor_status as enum('unverified', 'verified'); + create type aal_level as enum('aal1', 'aal2', 'aal3'); +exception + when duplicate_object then null; +end $$; + +-- auth.mfa_factors definition +create table if not exists {{ index .Options "Namespace" }}.mfa_factors( + id uuid not null, + user_id uuid not null, + friendly_name text null, + factor_type factor_type not null, + status factor_status not null, + created_at timestamptz not null, + updated_at timestamptz not null, + secret text null, + constraint mfa_factors_pkey primary key(id), + constraint mfa_factors_user_id_fkey foreign key (user_id) references {{ index .Options "Namespace" }}.users(id) on delete cascade +); +comment on table {{ index .Options "Namespace" }}.mfa_factors is 'auth: stores metadata about factors'; + +create unique index if not exists mfa_factors_user_friendly_name_unique on {{ index .Options "Namespace" }}.mfa_factors (friendly_name, user_id) where trim(friendly_name) <> ''; + +-- auth.mfa_challenges definition +create table if not exists {{ index .Options "Namespace" }}.mfa_challenges( + id uuid not null, + factor_id uuid not null, + created_at timestamptz not null, + verified_at timestamptz null, + ip_address inet not null, + constraint mfa_challenges_pkey primary key (id), + constraint mfa_challenges_auth_factor_id_fkey foreign key (factor_id) references {{ index .Options "Namespace" }}.mfa_factors(id) on delete cascade +); +comment on table {{ index .Options "Namespace" }}.mfa_challenges is 'auth: stores metadata about challenge requests made'; + + + +-- add factor_id and amr claims to session +create table if not exists {{ index .Options "Namespace" }}.mfa_amr_claims( + session_id uuid not null, + created_at timestamptz not null, + updated_at timestamptz not null, + authentication_method text not null, + constraint mfa_amr_claims_session_id_authentication_method_pkey unique(session_id, authentication_method), + constraint mfa_amr_claims_session_id_fkey foreign key(session_id) references {{ index .Options "Namespace" }}.sessions(id) on delete cascade +); +comment on table {{ index .Options "Namespace" }}.mfa_amr_claims is 'auth: stores authenticator method reference claims for multi factor authentication'; diff --git a/auth_v2.169.0/migrations/20221003041400_add_aal_and_factor_id_to_sessions.up.sql b/auth_v2.169.0/migrations/20221003041400_add_aal_and_factor_id_to_sessions.up.sql new file mode 100644 index 0000000..cc8a209 --- /dev/null +++ b/auth_v2.169.0/migrations/20221003041400_add_aal_and_factor_id_to_sessions.up.sql @@ -0,0 +1,3 @@ +-- add factor_id to sessions + alter table {{ index .Options "Namespace" }}.sessions add column if not exists factor_id uuid null; + alter table {{ index .Options "Namespace" }}.sessions add column if not exists aal aal_level null; diff --git 
a/auth_v2.169.0/migrations/20221011041400_add_mfa_indexes.up.sql b/auth_v2.169.0/migrations/20221011041400_add_mfa_indexes.up.sql new file mode 100644 index 0000000..def57a2 --- /dev/null +++ b/auth_v2.169.0/migrations/20221011041400_add_mfa_indexes.up.sql @@ -0,0 +1,19 @@ +alter table {{ index .Options "Namespace" }}.mfa_amr_claims + add column if not exists id uuid not null; + +do $$ +begin + if not exists + (select constraint_name + from information_schema.table_constraints + where table_schema = '{{ index .Options "Namespace" }}' + and table_name = 'mfa_amr_claims' + and constraint_name = 'amr_id_pk') + then + alter table {{ index .Options "Namespace" }}.mfa_amr_claims add constraint amr_id_pk primary key(id); + end if; +end $$; + +create index if not exists user_id_created_at_idx on {{ index .Options "Namespace" }}.sessions (user_id, created_at); +create index if not exists factor_id_created_at_idx on {{ index .Options "Namespace" }}.mfa_factors (user_id, created_at); + diff --git a/auth_v2.169.0/migrations/20221020193600_add_sessions_user_id_index.up.sql b/auth_v2.169.0/migrations/20221020193600_add_sessions_user_id_index.up.sql new file mode 100644 index 0000000..f5ba042 --- /dev/null +++ b/auth_v2.169.0/migrations/20221020193600_add_sessions_user_id_index.up.sql @@ -0,0 +1,2 @@ +create index if not exists sessions_user_id_idx on {{ index .Options "Namespace" }}.sessions (user_id); + diff --git a/auth_v2.169.0/migrations/20221021073300_add_refresh_tokens_session_id_revoked_index.up.sql b/auth_v2.169.0/migrations/20221021073300_add_refresh_tokens_session_id_revoked_index.up.sql new file mode 100644 index 0000000..0c47d4a --- /dev/null +++ b/auth_v2.169.0/migrations/20221021073300_add_refresh_tokens_session_id_revoked_index.up.sql @@ -0,0 +1 @@ +create index if not exists refresh_tokens_session_id_revoked_idx on {{ index .Options "Namespace" }}.refresh_tokens (session_id, revoked); diff --git a/auth_v2.169.0/migrations/20221021082433_add_saml.up.sql b/auth_v2.169.0/migrations/20221021082433_add_saml.up.sql new file mode 100644 index 0000000..30ac3d0 --- /dev/null +++ b/auth_v2.169.0/migrations/20221021082433_add_saml.up.sql @@ -0,0 +1,90 @@ +-- Multi-instance mode (see auth.instances) table intentionally not supported and ignored. 
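These migration files are not raw SQL: every {{ index .Options "Namespace" }} placeholder is rendered before execution, with the namespace taken from configuration (the example env above sets DB_NAMESPACE="auth"). The snippet below is a minimal stand-in using Go's text/template to show what the placeholder expands to; it is illustrative, not the project's migration runner.

package main

import (
	"os"
	"text/template"
)

func main() {
	// One statement from the migrations above, with the namespace placeholder intact.
	const ddl = `create index if not exists sessions_user_id_idx on {{ index .Options "Namespace" }}.sessions (user_id);`

	data := struct{ Options map[string]string }{
		Options: map[string]string{"Namespace": "auth"},
	}

	tmpl := template.Must(template.New("migration").Parse(ddl))
	// Prints: create index if not exists sessions_user_id_idx on auth.sessions (user_id);
	if err := tmpl.Execute(os.Stdout, data); err != nil {
		panic(err)
	}
}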
+ +create table if not exists {{ index .Options "Namespace" }}.sso_providers ( + id uuid not null, + resource_id text null, + created_at timestamptz null, + updated_at timestamptz null, + primary key (id), + constraint "resource_id not empty" check (resource_id = null or char_length(resource_id) > 0) +); + +comment on table {{ index .Options "Namespace" }}.sso_providers is 'Auth: Manages SSO identity provider information; see saml_providers for SAML.'; +comment on column {{ index .Options "Namespace" }}.sso_providers.resource_id is 'Auth: Uniquely identifies a SSO provider according to a user-chosen resource ID (case insensitive), useful in infrastructure as code.'; + +create unique index if not exists sso_providers_resource_id_idx on {{ index .Options "Namespace" }}.sso_providers (lower(resource_id)); + +create table if not exists {{ index .Options "Namespace" }}.sso_domains ( + id uuid not null, + sso_provider_id uuid not null, + domain text not null, + created_at timestamptz null, + updated_at timestamptz null, + primary key (id), + foreign key (sso_provider_id) references {{ index .Options "Namespace" }}.sso_providers (id) on delete cascade, + constraint "domain not empty" check (char_length(domain) > 0) +); + +create index if not exists sso_domains_sso_provider_id_idx on {{ index .Options "Namespace" }}.sso_domains (sso_provider_id); +create unique index if not exists sso_domains_domain_idx on {{ index .Options "Namespace" }}.sso_domains (lower(domain)); + +comment on table {{ index .Options "Namespace" }}.sso_domains is 'Auth: Manages SSO email address domain mapping to an SSO Identity Provider.'; + +create table if not exists {{ index .Options "Namespace" }}.saml_providers ( + id uuid not null, + sso_provider_id uuid not null, + entity_id text not null unique, + metadata_xml text not null, + metadata_url text null, + attribute_mapping jsonb null, + created_at timestamptz null, + updated_at timestamptz null, + primary key (id), + foreign key (sso_provider_id) references {{ index .Options "Namespace" }}.sso_providers (id) on delete cascade, + constraint "metadata_xml not empty" check (char_length(metadata_xml) > 0), + constraint "metadata_url not empty" check (metadata_url = null or char_length(metadata_url) > 0), + constraint "entity_id not empty" check (char_length(entity_id) > 0) +); + +create index if not exists saml_providers_sso_provider_id_idx on {{ index .Options "Namespace" }}.saml_providers (sso_provider_id); + +comment on table {{ index .Options "Namespace" }}.saml_providers is 'Auth: Manages SAML Identity Provider connections.'; + +create table if not exists {{ index .Options "Namespace" }}.saml_relay_states ( + id uuid not null, + sso_provider_id uuid not null, + request_id text not null, + for_email text null, + redirect_to text null, + from_ip_address inet null, + created_at timestamptz null, + updated_at timestamptz null, + primary key (id), + foreign key (sso_provider_id) references {{ index .Options "Namespace" }}.sso_providers (id) on delete cascade, + constraint "request_id not empty" check(char_length(request_id) > 0) +); + +create index if not exists saml_relay_states_sso_provider_id_idx on {{ index .Options "Namespace" }}.saml_relay_states (sso_provider_id); +create index if not exists saml_relay_states_for_email_idx on {{ index .Options "Namespace" }}.saml_relay_states (for_email); + +comment on table {{ index .Options "Namespace" }}.saml_relay_states is 'Auth: Contains SAML Relay State information for each Service Provider initiated login.'; + +create table if 
not exists {{ index .Options "Namespace" }}.sso_sessions ( + id uuid not null, + session_id uuid not null, + sso_provider_id uuid null, + not_before timestamptz null, + not_after timestamptz null, + idp_initiated boolean default false, + created_at timestamptz null, + updated_at timestamptz null, + primary key (id), + foreign key (session_id) references {{ index .Options "Namespace" }}.sessions (id) on delete cascade, + foreign key (sso_provider_id) references {{ index .Options "Namespace" }}.sso_providers (id) on delete cascade +); + +create index if not exists sso_sessions_session_id_idx on {{ index .Options "Namespace" }}.sso_sessions (session_id); +create index if not exists sso_sessions_sso_provider_id_idx on {{ index .Options "Namespace" }}.sso_sessions (sso_provider_id); + +comment on table {{ index .Options "Namespace" }}.sso_sessions is 'Auth: A session initiated by an SSO Identity Provider'; + diff --git a/auth_v2.169.0/migrations/20221027105023_add_identities_user_id_idx.up.sql b/auth_v2.169.0/migrations/20221027105023_add_identities_user_id_idx.up.sql new file mode 100644 index 0000000..12e7aa5 --- /dev/null +++ b/auth_v2.169.0/migrations/20221027105023_add_identities_user_id_idx.up.sql @@ -0,0 +1 @@ +create index if not exists identities_user_id_idx on {{ index .Options "Namespace" }}.identities using btree (user_id); diff --git a/auth_v2.169.0/migrations/20221114143122_add_session_not_after_column.up.sql b/auth_v2.169.0/migrations/20221114143122_add_session_not_after_column.up.sql new file mode 100644 index 0000000..c729911 --- /dev/null +++ b/auth_v2.169.0/migrations/20221114143122_add_session_not_after_column.up.sql @@ -0,0 +1,4 @@ +alter table only {{ index .Options "Namespace" }}.sessions + add column if not exists not_after timestamptz; + +comment on column {{ index .Options "Namespace" }}.sessions.not_after is 'Auth: Not after is a nullable column that contains a timestamp after which the session should be regarded as expired.'; diff --git a/auth_v2.169.0/migrations/20221114143410_remove_parent_foreign_key_refresh_tokens.up.sql b/auth_v2.169.0/migrations/20221114143410_remove_parent_foreign_key_refresh_tokens.up.sql new file mode 100644 index 0000000..62d2078 --- /dev/null +++ b/auth_v2.169.0/migrations/20221114143410_remove_parent_foreign_key_refresh_tokens.up.sql @@ -0,0 +1,2 @@ +alter table only {{ index .Options "Namespace" }}.refresh_tokens + drop constraint refresh_tokens_parent_fkey; diff --git a/auth_v2.169.0/migrations/20221125140132_backfill_email_identity.up.sql b/auth_v2.169.0/migrations/20221125140132_backfill_email_identity.up.sql new file mode 100644 index 0000000..cd06425 --- /dev/null +++ b/auth_v2.169.0/migrations/20221125140132_backfill_email_identity.up.sql @@ -0,0 +1,11 @@ +-- backfill the auth.identities column by adding an email identity +-- for all auth.users with an email and password + +do $$ +begin + insert into {{ index .Options "Namespace" }}.identities (id, user_id, identity_data, provider, last_sign_in_at, created_at, updated_at) + select id, id as user_id, jsonb_build_object('sub', id, 'email', email) as identity_data, 'email' as provider, null as last_sign_in_at, '2022-11-25' as created_at, '2022-11-25' as updated_at + from {{ index .Options "Namespace" }}.users as users + where encrypted_password != '' and email is not null and not exists(select user_id from {{ index .Options "Namespace" }}.identities where user_id = users.id); +end; +$$; diff --git a/auth_v2.169.0/migrations/20221208132122_backfill_email_last_sign_in_at.up.sql 
b/auth_v2.169.0/migrations/20221208132122_backfill_email_last_sign_in_at.up.sql new file mode 100644 index 0000000..19ec79e --- /dev/null +++ b/auth_v2.169.0/migrations/20221208132122_backfill_email_last_sign_in_at.up.sql @@ -0,0 +1,13 @@ +-- previous backfill migration left last_sign_in_at to be null, which broke some projects + +do $$ +begin +update {{ index .Options "Namespace" }}.identities + set last_sign_in_at = '2022-11-25' + where + last_sign_in_at is null and + created_at = '2022-11-25' and + updated_at = '2022-11-25' and + provider = 'email' and + id = user_id::text; +end $$; diff --git a/auth_v2.169.0/migrations/20221215195500_modify_users_email_unique_index.up.sql b/auth_v2.169.0/migrations/20221215195500_modify_users_email_unique_index.up.sql new file mode 100644 index 0000000..c12de04 --- /dev/null +++ b/auth_v2.169.0/migrations/20221215195500_modify_users_email_unique_index.up.sql @@ -0,0 +1,23 @@ +-- this change is relatively temporary +-- it is meant to keep database consistency guarantees until there is proper +-- introduction of account linking / merging / delinking APIs, at which point +-- rows in the users table will allow duplicates but with programmatic control + +alter table only {{ index .Options "Namespace" }}.users + add column if not exists is_sso_user boolean not null default false; + +comment on column {{ index .Options "Namespace" }}.users.is_sso_user is 'Auth: Set this column to true when the account comes from SSO. These accounts can have duplicate emails.'; + +do $$ +begin + alter table only {{ index .Options "Namespace" }}.users + drop constraint if exists users_email_key; +exception +-- dependent object: https://www.postgresql.org/docs/current/errcodes-appendix.html +when SQLSTATE '2BP01' then + raise notice 'Unable to drop users_email_key constraint due to dependent objects, please resolve this manually or SSO may not work'; +end $$; + +create unique index if not exists users_email_partial_key on {{ index .Options "Namespace" }}.users (email) where (is_sso_user = false); + +comment on index {{ index .Options "Namespace" }}.users_email_partial_key is 'Auth: A partial unique index that applies only when is_sso_user is false'; diff --git a/auth_v2.169.0/migrations/20221215195800_add_identities_email_column.up.sql b/auth_v2.169.0/migrations/20221215195800_add_identities_email_column.up.sql new file mode 100644 index 0000000..eb60334 --- /dev/null +++ b/auth_v2.169.0/migrations/20221215195800_add_identities_email_column.up.sql @@ -0,0 +1,18 @@ +do $$ +begin + update + {{ index .Options "Namespace" }}.identities as identities + set + identity_data = identity_data || jsonb_build_object('email', (select email from {{ index .Options "Namespace" }}.users where id = identities.user_id)), + updated_at = '2022-11-25' + where identities.provider = 'email' and identity_data->>'email' is null; +end $$; + +alter table only {{ index .Options "Namespace" }}.identities + add column if not exists email text generated always as (lower(identity_data->>'email')) stored; + +comment on column {{ index .Options "Namespace" }}.identities.email is 'Auth: Email is a generated column that references the optional email property in the identity_data'; + +create index if not exists identities_email_idx on {{ index .Options "Namespace" }}.identities (email text_pattern_ops); + +comment on index {{ index .Options "Namespace" }}.identities_email_idx is 'Auth: Ensures indexed queries on the email column'; diff --git a/auth_v2.169.0/migrations/20221215195900_remove_sso_sessions.up.sql 
b/auth_v2.169.0/migrations/20221215195900_remove_sso_sessions.up.sql new file mode 100644 index 0000000..228302d --- /dev/null +++ b/auth_v2.169.0/migrations/20221215195900_remove_sso_sessions.up.sql @@ -0,0 +1,3 @@ +-- sso_sessions is not used as all of the necessary data is in sessions +drop table if exists {{ index .Options "Namespace" }}.sso_sessions; + diff --git a/auth_v2.169.0/migrations/20230116124310_alter_phone_type.up.sql b/auth_v2.169.0/migrations/20230116124310_alter_phone_type.up.sql new file mode 100644 index 0000000..fa846db --- /dev/null +++ b/auth_v2.169.0/migrations/20230116124310_alter_phone_type.up.sql @@ -0,0 +1,14 @@ +-- alter phone field column type to accommodate soft deletion + +do $$ +begin + alter table {{ index .Options "Namespace" }}.users + alter column phone type text, + alter column phone_change type text; +exception + -- SQLSTATE errcodes https://www.postgresql.org/docs/current/errcodes-appendix.html + when SQLSTATE '0A000' then + raise notice 'Unable to change data type of phone, phone_change columns due to use by a view or rule'; + when SQLSTATE '2BP01' then + raise notice 'Unable to change data type of phone, phone_change columns due to dependent objects'; +end $$; diff --git a/auth_v2.169.0/migrations/20230116124412_add_deleted_at.up.sql b/auth_v2.169.0/migrations/20230116124412_add_deleted_at.up.sql new file mode 100644 index 0000000..999abaa --- /dev/null +++ b/auth_v2.169.0/migrations/20230116124412_add_deleted_at.up.sql @@ -0,0 +1,4 @@ +-- adds deleted_at column to auth.users + +alter table {{ index .Options "Namespace" }}.users +add column if not exists deleted_at timestamptz null; diff --git a/auth_v2.169.0/migrations/20230131181311_backfill_invite_identities.up.sql b/auth_v2.169.0/migrations/20230131181311_backfill_invite_identities.up.sql new file mode 100644 index 0000000..2fcb358 --- /dev/null +++ b/auth_v2.169.0/migrations/20230131181311_backfill_invite_identities.up.sql @@ -0,0 +1,9 @@ +-- backfills the missing email identity for invited users + +do $$ +begin + insert into {{ index .Options "Namespace" }}.identities (id, user_id, identity_data, provider, last_sign_in_at, created_at, updated_at) + select id, id as user_id, jsonb_build_object('sub', id, 'email', email) as identity_data, 'email' as provider, null as last_sign_in_at, '2023-01-25' as created_at, '2023-01-25' as updated_at + from {{ index .Options "Namespace" }}.users as users + where invited_at is not null and not exists (select user_id from {{ index .Options "Namespace" }}.identities where user_id = users.id and provider = 'email'); +end $$; diff --git a/auth_v2.169.0/migrations/20230322519590_add_flow_state_table.up.sql b/auth_v2.169.0/migrations/20230322519590_add_flow_state_table.up.sql new file mode 100644 index 0000000..a8842e5 --- /dev/null +++ b/auth_v2.169.0/migrations/20230322519590_add_flow_state_table.up.sql @@ -0,0 +1,20 @@ +-- see: https://stackoverflow.com/questions/7624919/check-if-a-user-defined-type-already-exists-in-postgresql/48382296#48382296 +do $$ begin + create type code_challenge_method as enum('s256', 'plain'); +exception + when duplicate_object then null; +end $$; +create table if not exists {{ index .Options "Namespace" }}.flow_state( + id uuid primary key, + user_id uuid null, + auth_code text not null, + code_challenge_method code_challenge_method not null, + code_challenge text not null, + provider_type text not null, + provider_access_token text null, + provider_refresh_token text null, + created_at timestamptz null, + updated_at timestamptz
null +); +create index if not exists idx_auth_code on {{ index .Options "Namespace" }}.flow_state(auth_code); +comment on table {{ index .Options "Namespace" }}.flow_state is 'stores metadata for pkce logins'; diff --git a/auth_v2.169.0/migrations/20230402418590_add_authentication_method_to_flow_state_table.up.sql b/auth_v2.169.0/migrations/20230402418590_add_authentication_method_to_flow_state_table.up.sql new file mode 100644 index 0000000..e83af85 --- /dev/null +++ b/auth_v2.169.0/migrations/20230402418590_add_authentication_method_to_flow_state_table.up.sql @@ -0,0 +1,6 @@ +alter table {{index .Options "Namespace" }}.flow_state +add column if not exists authentication_method text not null; +create index if not exists idx_user_id_auth_method on {{index .Options "Namespace" }}.flow_state (user_id, authentication_method); + +-- Update comment as we have generalized the table +comment on table {{ index .Options "Namespace" }}.flow_state is 'stores metadata for pkce logins'; diff --git a/auth_v2.169.0/migrations/20230411005111_remove_duplicate_idx.up.sql b/auth_v2.169.0/migrations/20230411005111_remove_duplicate_idx.up.sql new file mode 100644 index 0000000..dc23931 --- /dev/null +++ b/auth_v2.169.0/migrations/20230411005111_remove_duplicate_idx.up.sql @@ -0,0 +1 @@ +drop index if exists {{index .Options "Namespace" }}.refresh_tokens_token_idx; diff --git a/auth_v2.169.0/migrations/20230508135423_add_cleanup_indexes.up.sql b/auth_v2.169.0/migrations/20230508135423_add_cleanup_indexes.up.sql new file mode 100644 index 0000000..162acee --- /dev/null +++ b/auth_v2.169.0/migrations/20230508135423_add_cleanup_indexes.up.sql @@ -0,0 +1,17 @@ +-- Indexes used for cleaning up old or stale objects. + +create index if not exists + refresh_tokens_updated_at_idx + on {{ index .Options "Namespace" }}.refresh_tokens (updated_at desc); + +create index if not exists + flow_state_created_at_idx + on {{ index .Options "Namespace" }}.flow_state (created_at desc); + +create index if not exists + saml_relay_states_created_at_idx + on {{ index .Options "Namespace" }}.saml_relay_states (created_at desc); + +create index if not exists + sessions_not_after_idx + on {{ index .Options "Namespace" }}.sessions (not_after desc); diff --git a/auth_v2.169.0/migrations/20230523124323_add_mfa_challenge_cleanup_index.up.sql b/auth_v2.169.0/migrations/20230523124323_add_mfa_challenge_cleanup_index.up.sql new file mode 100644 index 0000000..667d502 --- /dev/null +++ b/auth_v2.169.0/migrations/20230523124323_add_mfa_challenge_cleanup_index.up.sql @@ -0,0 +1,5 @@ +-- Index used to clean up mfa challenges + +create index if not exists + mfa_challenge_created_at_idx + on {{ index .Options "Namespace" }}.mfa_challenges (created_at desc); diff --git a/auth_v2.169.0/migrations/20230818113222_add_flow_state_to_relay_state.up.sql b/auth_v2.169.0/migrations/20230818113222_add_flow_state_to_relay_state.up.sql new file mode 100644 index 0000000..f940e70 --- /dev/null +++ b/auth_v2.169.0/migrations/20230818113222_add_flow_state_to_relay_state.up.sql @@ -0,0 +1 @@ +alter table {{ index .Options "Namespace" }}.saml_relay_states add column if not exists flow_state_id uuid references {{ index .Options "Namespace" }}.flow_state(id) on delete cascade default null; diff --git a/auth_v2.169.0/migrations/20230914180801_add_mfa_factors_user_id_idx.up.sql b/auth_v2.169.0/migrations/20230914180801_add_mfa_factors_user_id_idx.up.sql new file mode 100644 index 0000000..805c97c --- /dev/null +++ 
b/auth_v2.169.0/migrations/20230914180801_add_mfa_factors_user_id_idx.up.sql @@ -0,0 +1 @@ +create index if not exists mfa_factors_user_id_idx on {{ index .Options "Namespace" }}.mfa_factors(user_id); diff --git a/auth_v2.169.0/migrations/20231027141322_add_session_refresh_columns.up.sql b/auth_v2.169.0/migrations/20231027141322_add_session_refresh_columns.up.sql new file mode 100644 index 0000000..79efba9 --- /dev/null +++ b/auth_v2.169.0/migrations/20231027141322_add_session_refresh_columns.up.sql @@ -0,0 +1,4 @@ +alter table if exists {{ index .Options "Namespace" }}.sessions + add column if not exists refreshed_at timestamp without time zone, + add column if not exists user_agent text, + add column if not exists ip inet; diff --git a/auth_v2.169.0/migrations/20231114161723_add_sessions_tag.up.sql b/auth_v2.169.0/migrations/20231114161723_add_sessions_tag.up.sql new file mode 100644 index 0000000..7acf1bb --- /dev/null +++ b/auth_v2.169.0/migrations/20231114161723_add_sessions_tag.up.sql @@ -0,0 +1,2 @@ +alter table if exists {{ index .Options "Namespace" }}.sessions + add column if not exists tag text; diff --git a/auth_v2.169.0/migrations/20231117164230_add_id_pkey_identities.up.sql b/auth_v2.169.0/migrations/20231117164230_add_id_pkey_identities.up.sql new file mode 100644 index 0000000..31ed280 --- /dev/null +++ b/auth_v2.169.0/migrations/20231117164230_add_id_pkey_identities.up.sql @@ -0,0 +1,29 @@ +do $$ +begin + if not exists(select * + from information_schema.columns + where table_schema = '{{ index .Options "Namespace" }}' and table_name='identities' and column_name='provider_id') + then + alter table if exists {{ index .Options "Namespace" }}.identities + rename column id to provider_id; + end if; +end$$; + +alter table if exists {{ index .Options "Namespace" }}.identities + drop constraint if exists identities_pkey, + add column if not exists id uuid default gen_random_uuid() primary key; + +do $$ +begin + if not exists + (select constraint_name + from information_schema.table_constraints + where table_schema = '{{ index .Options "Namespace" }}' + and table_name = 'identities' + and constraint_name = 'identities_provider_id_provider_unique') + then + alter table if exists {{ index .Options "Namespace" }}.identities + add constraint identities_provider_id_provider_unique + unique(provider_id, provider); + end if; +end $$; diff --git a/auth_v2.169.0/migrations/20240115144230_remove_ip_address_from_saml_relay_state.up.sql b/auth_v2.169.0/migrations/20240115144230_remove_ip_address_from_saml_relay_state.up.sql new file mode 100644 index 0000000..169ec37 --- /dev/null +++ b/auth_v2.169.0/migrations/20240115144230_remove_ip_address_from_saml_relay_state.up.sql @@ -0,0 +1,7 @@ +do $$ +begin + if exists (select from information_schema.columns where table_schema = '{{ index .Options "Namespace" }}' and table_name = 'saml_relay_states' and column_name = 'from_ip_address') then + alter table {{ index .Options "Namespace" }}.saml_relay_states drop column from_ip_address; + end if; +end +$$; diff --git a/auth_v2.169.0/migrations/20240214120130_add_is_anonymous_column.up.sql b/auth_v2.169.0/migrations/20240214120130_add_is_anonymous_column.up.sql new file mode 100644 index 0000000..6ef963f --- /dev/null +++ b/auth_v2.169.0/migrations/20240214120130_add_is_anonymous_column.up.sql @@ -0,0 +1,8 @@ +do $$ +begin + alter table {{ index .Options "Namespace" }}.users + add column if not exists is_anonymous boolean not null default false; + + create index if not exists users_is_anonymous_idx on {{ 
index .Options "Namespace" }}.users using btree (is_anonymous); +end +$$; diff --git a/auth_v2.169.0/migrations/20240306115329_add_issued_at_to_flow_state.up.sql b/auth_v2.169.0/migrations/20240306115329_add_issued_at_to_flow_state.up.sql new file mode 100644 index 0000000..d6eff15 --- /dev/null +++ b/auth_v2.169.0/migrations/20240306115329_add_issued_at_to_flow_state.up.sql @@ -0,0 +1,3 @@ +do $$ begin +alter table {{ index .Options "Namespace" }}.flow_state add column if not exists auth_code_issued_at timestamptz null; +end $$ diff --git a/auth_v2.169.0/migrations/20240314092811_add_saml_name_id_format.up.sql b/auth_v2.169.0/migrations/20240314092811_add_saml_name_id_format.up.sql new file mode 100644 index 0000000..0196250 --- /dev/null +++ b/auth_v2.169.0/migrations/20240314092811_add_saml_name_id_format.up.sql @@ -0,0 +1,3 @@ +do $$ begin +alter table {{ index .Options "Namespace" }}.saml_providers add column if not exists name_id_format text null; +end $$ diff --git a/auth_v2.169.0/migrations/20240427152123_add_one_time_tokens_table.up.sql b/auth_v2.169.0/migrations/20240427152123_add_one_time_tokens_table.up.sql new file mode 100644 index 0000000..be73126 --- /dev/null +++ b/auth_v2.169.0/migrations/20240427152123_add_one_time_tokens_table.up.sql @@ -0,0 +1,37 @@ +do $$ begin + create type one_time_token_type as enum ( + 'confirmation_token', + 'reauthentication_token', + 'recovery_token', + 'email_change_token_new', + 'email_change_token_current', + 'phone_change_token' + ); +exception + when duplicate_object then null; +end $$; + + +do $$ begin + create table if not exists {{ index .Options "Namespace" }}.one_time_tokens ( + id uuid primary key, + user_id uuid not null references {{ index .Options "Namespace" }}.users on delete cascade, + token_type one_time_token_type not null, + token_hash text not null, + relates_to text not null, + created_at timestamp without time zone not null default now(), + updated_at timestamp without time zone not null default now(), + check (char_length(token_hash) > 0) + ); + + begin + create index if not exists one_time_tokens_token_hash_hash_idx on {{ index .Options "Namespace" }}.one_time_tokens using hash (token_hash); + create index if not exists one_time_tokens_relates_to_hash_idx on {{ index .Options "Namespace" }}.one_time_tokens using hash (relates_to); + exception when others then + -- Fallback to btree indexes if hash creation fails + create index if not exists one_time_tokens_token_hash_hash_idx on {{ index .Options "Namespace" }}.one_time_tokens using btree (token_hash); + create index if not exists one_time_tokens_relates_to_hash_idx on {{ index .Options "Namespace" }}.one_time_tokens using btree (relates_to); + end; + + create unique index if not exists one_time_tokens_user_id_token_type_key on {{ index .Options "Namespace" }}.one_time_tokens (user_id, token_type); +end $$; diff --git a/auth_v2.169.0/migrations/20240612123726_enable_rls_update_grants.up.sql b/auth_v2.169.0/migrations/20240612123726_enable_rls_update_grants.up.sql new file mode 100644 index 0000000..9201e84 --- /dev/null +++ b/auth_v2.169.0/migrations/20240612123726_enable_rls_update_grants.up.sql @@ -0,0 +1,36 @@ +do $$ begin + -- enable RLS policy on auth tables + alter table {{ index .Options "Namespace" }}.schema_migrations enable row level security; + alter table {{ index .Options "Namespace" }}.instances enable row level security; + alter table {{ index .Options "Namespace" }}.users enable row level security; + alter table {{ index .Options "Namespace" 
}}.audit_log_entries enable row level security; + alter table {{ index .Options "Namespace" }}.saml_relay_states enable row level security; + alter table {{ index .Options "Namespace" }}.refresh_tokens enable row level security; + alter table {{ index .Options "Namespace" }}.mfa_factors enable row level security; + alter table {{ index .Options "Namespace" }}.sessions enable row level security; + alter table {{ index .Options "Namespace" }}.sso_providers enable row level security; + alter table {{ index .Options "Namespace" }}.sso_domains enable row level security; + alter table {{ index .Options "Namespace" }}.mfa_challenges enable row level security; + alter table {{ index .Options "Namespace" }}.mfa_amr_claims enable row level security; + alter table {{ index .Options "Namespace" }}.saml_providers enable row level security; + alter table {{ index .Options "Namespace" }}.flow_state enable row level security; + alter table {{ index .Options "Namespace" }}.identities enable row level security; + alter table {{ index .Options "Namespace" }}.one_time_tokens enable row level security; + -- allow postgres role to select from auth tables and allow it to grant select to other roles + grant select on {{ index .Options "Namespace" }}.schema_migrations to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.instances to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.users to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.audit_log_entries to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.saml_relay_states to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.refresh_tokens to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.mfa_factors to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.sessions to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.sso_providers to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.sso_domains to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.mfa_challenges to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.mfa_amr_claims to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.saml_providers to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.flow_state to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.identities to postgres with grant option; + grant select on {{ index .Options "Namespace" }}.one_time_tokens to postgres with grant option; +end $$; diff --git a/auth_v2.169.0/migrations/20240729123726_add_mfa_phone_config.up.sql b/auth_v2.169.0/migrations/20240729123726_add_mfa_phone_config.up.sql new file mode 100644 index 0000000..ec94d7b --- /dev/null +++ b/auth_v2.169.0/migrations/20240729123726_add_mfa_phone_config.up.sql @@ -0,0 +1,12 @@ +do $$ begin + alter type {{ index .Options "Namespace" }}.factor_type add value 'phone'; +exception + when duplicate_object then null; +end $$; + + +alter table {{ index .Options "Namespace" }}.mfa_factors add column if not exists phone text unique default null; +alter table {{ index .Options "Namespace" }}.mfa_challenges add column if not exists otp_code text null; + + +create unique index if not exists unique_verified_phone_factor on {{ index .Options "Namespace" }}.mfa_factors (user_id, phone); diff --git 
a/auth_v2.169.0/migrations/20240802193726_add_mfa_factors_column_last_challenged_at.up.sql b/auth_v2.169.0/migrations/20240802193726_add_mfa_factors_column_last_challenged_at.up.sql new file mode 100644 index 0000000..bc3eea9 --- /dev/null +++ b/auth_v2.169.0/migrations/20240802193726_add_mfa_factors_column_last_challenged_at.up.sql @@ -0,0 +1 @@ +alter table {{ index .Options "Namespace" }}.mfa_factors add column if not exists last_challenged_at timestamptz unique default null; diff --git a/auth_v2.169.0/migrations/20240806073726_drop_uniqueness_constraint_on_phone.up.sql b/auth_v2.169.0/migrations/20240806073726_drop_uniqueness_constraint_on_phone.up.sql new file mode 100644 index 0000000..ade27ea --- /dev/null +++ b/auth_v2.169.0/migrations/20240806073726_drop_uniqueness_constraint_on_phone.up.sql @@ -0,0 +1,22 @@ +alter table {{ index .Options "Namespace" }}.mfa_factors drop constraint if exists mfa_factors_phone_key; +do $$ +begin + -- if both indexes exist, it means that the schema_migrations table was truncated and the migrations had to be rerun + if ( + select count(*) = 2 + from pg_indexes + where indexname in ('unique_verified_phone_factor', 'unique_phone_factor_per_user') + and schemaname = '{{ index .Options "Namespace" }}' + ) then + execute 'drop index {{ index .Options "Namespace" }}.unique_verified_phone_factor'; + end if; + + if exists ( + select 1 + from pg_indexes + where indexname = 'unique_verified_phone_factor' + and schemaname = '{{ index .Options "Namespace" }}' + ) then + execute 'alter index {{ index .Options "Namespace" }}.unique_verified_phone_factor rename to unique_phone_factor_per_user'; + end if; +end $$; diff --git a/auth_v2.169.0/migrations/20241009103726_add_web_authn.up.sql b/auth_v2.169.0/migrations/20241009103726_add_web_authn.up.sql new file mode 100644 index 0000000..04d8972 --- /dev/null +++ b/auth_v2.169.0/migrations/20241009103726_add_web_authn.up.sql @@ -0,0 +1,3 @@ +alter table {{ index .Options "Namespace" }}.mfa_factors add column if not exists web_authn_credential jsonb null; +alter table {{ index .Options "Namespace" }}.mfa_factors add column if not exists web_authn_aaguid uuid null; +alter table {{ index .Options "Namespace" }}.mfa_challenges add column if not exists web_authn_session_data jsonb null; diff --git a/auth_v2.169.0/openapi.yaml b/auth_v2.169.0/openapi.yaml new file mode 100644 index 0000000..1f52436 --- /dev/null +++ b/auth_v2.169.0/openapi.yaml @@ -0,0 +1,2349 @@ +openapi: 3.0.3 +info: + version: latest + title: GoTrue REST API (Supabase Auth) + description: |- + GoTrue is the software behind [Supabase Auth](https://supabase.com/auth). This is its REST API. + + **Notes:** + - HTTP 5XX errors are not listed for each endpoint. + These should be handled globally. Not all HTTP 5XX errors are generated from GoTrue, and they may serve non-JSON content. Make sure you inspect the `Content-Type` header before parsing as JSON. + - Error responses are somewhat inconsistent. + Avoid using the `msg` and HTTP status code to identify errors. HTTP 400 and 422 are used interchangeably in many APIs. + - If the server has CAPTCHA protection enabled, the verification token should be included in the request body. + - Rate limit errors are consistently raised with the HTTP 429 code. + - Enums are used only in request bodies / parameters and not in responses to ensure wide compatibility with code generators that fail to include an unknown enum case. + + **Backward compatibility:** + - Endpoints marked as _Experimental_ may change without notice. 
+ - Endpoints marked as _Deprecated_ will be supported for at least 3 months since being marked as deprecated. + - HTTP status codes like 400, 404, 422 may change for the same underlying error condition. + + termsOfService: https://supabase.com/terms + contact: + name: Ask a question about this API + url: https://github.com/supabase/supabase/discussions + license: + name: MIT License + url: https://github.com/supabase/gotrue/blob/master/LICENSE +externalDocs: + description: Learn more about Supabase Auth + url: https://supabase.com/docs/guides/auth/overview +servers: + - url: "https://{project}.supabase.co/auth/v1" + variables: + project: + description: > + Your Supabase project ID. + default: abcdefghijklmnopqrst +tags: + - name: auth + description: APIs for authentication and authorization. + - name: user + description: APIs used by a user to manage their account. + - name: oauth + description: APIs for dealing with OAuth flows. + - name: oidc + description: APIs for dealing with OIDC authentication flows. (Experimental.) + - name: sso + description: APIs for authenticating using SSO providers (SAML). (Experimental.) + - name: saml + description: SAML 2.0 Endpoints. (Experimental.) + - name: admin + description: Administration APIs requiring elevated access. + - name: general + description: General APIs. +paths: + /token: + post: + summary: Issues access and refresh tokens based on grant type. + tags: + - auth + - oidc + parameters: + - name: grant_type + in: query + required: true + description: > + What grant type should be used to issue an access and refresh token. Note that `id_token` is only offered in experimental mode. CAPTCHA protection is not effective on the `refresh_token` grant flow. + schema: + type: string + enum: + - password + - refresh_token + - id_token + - pkce + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + examples: + grant_type=password: + value: + email: user@example.com + password: password1 + grant_type=refresh_token: + value: + refresh_token: 4nYUCw0wZR_DNOTSDbSGMQ + grant_type=pkce: + value: + auth_code: 009e5066-fc11-4eca-8c8c-6fd82aa263f2 + code_verifier: ktPNXpR65N6JtgzQA8_5HHtH6PBSAahMNoLKRzQEa0Tzgl.vdV~b6lPk004XOd.4lR0inCde.NoQx5K63xPfzL8o7tJAjXncnhw5Niv9ycQ.QRV9JG.y3VapqbgLfIrJ + schema: + type: object + description: |- + For the refresh token flow, supply only `refresh_token`. + For the email/phone with password flow, supply `email`, `phone` and `password` with an optional `gotrue_meta_security`. + For the OIDC ID token flow, supply `id_token`, `nonce`, `provider`, `client_id`, `issuer` with an optional `gotrue_meta_security`. + properties: + refresh_token: + type: string + password: + type: string + email: + type: string + format: email + phone: + type: string + format: phone + id_token: + type: string + access_token: + type: string + description: Provide only when `grant_type` is `id_token` and the provided ID token requires the presence of an access token to be accepted (usually by having an `at_hash` claim). + nonce: + type: string + provider: + type: string + enum: + - google + - apple + - azure + - facebook + - keycloak + client_id: + type: string + issuer: + type: string + description: If `provider` is `azure` then you can specify any Azure OIDC issuer string here, which will be used for verification. 
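Editor's note: a minimal TypeScript sketch of the password and refresh_token grants described above, not part of the vendored openapi.yaml. The base URL, the `apikey` header convention, the key value and the helper names are assumptions about a self-hosted deployment; Node 18+ (global fetch) is assumed, and the response type is abbreviated from AccessTokenResponseSchema.

// Hypothetical helper module; names and endpoints' base URL are illustrative only.
const AUTH_URL = "http://localhost:9999";   // assumption: self-hosted GoTrue base URL
const ANON_KEY = "<anon-api-key>";          // assumption: APIKeyAuth sent as an `apikey` header

interface TokenResponse {                   // abbreviated view of AccessTokenResponseSchema
  access_token: string;
  token_type: string;
  expires_in: number;
  refresh_token: string;
}

export async function signInWithPassword(email: string, password: string): Promise<TokenResponse> {
  const res = await fetch(`${AUTH_URL}/token?grant_type=password`, {
    method: "POST",
    headers: { apikey: ANON_KEY, "Content-Type": "application/json" },
    body: JSON.stringify({ email, password }),
  });
  // Per the notes in the spec, 5xx bodies may not be JSON, so check Content-Type first.
  if (!res.ok) {
    const isJson = res.headers.get("content-type")?.includes("application/json");
    const detail = isJson ? JSON.stringify(await res.json()) : await res.text();
    throw new Error(`token request failed (${res.status}): ${detail}`);
  }
  return res.json() as Promise<TokenResponse>;
}

export async function refreshSession(refresh_token: string): Promise<TokenResponse> {
  const res = await fetch(`${AUTH_URL}/token?grant_type=refresh_token`, {
    method: "POST",
    headers: { apikey: ANON_KEY, "Content-Type": "application/json" },
    body: JSON.stringify({ refresh_token }),
  });
  if (!res.ok) throw new Error(`refresh failed: ${res.status}`);
  return res.json() as Promise<TokenResponse>;
}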
+ gotrue_meta_security: + $ref: "#/components/schemas/GoTrueMetaSecurity" + auth_code: + type: string + format: uuid + code_verifier: + type: string + responses: + 200: + description: > + An access and refresh token have been successfully issued. + content: + application/json: + schema: + $ref: "#/components/schemas/AccessTokenResponseSchema" + + 400: + $ref: "#/components/responses/BadRequestResponse" + 401: + $ref: "#/components/responses/ForbiddenResponse" + 403: + $ref: "#/components/responses/UnauthorizedResponse" + 500: + $ref: "#/components/responses/InternalServerErrorResponse" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /logout: + post: + summary: Logs out a user. + tags: + - auth + security: + - APIKeyAuth: [] + UserAuth: [] + parameters: + - name: scope + in: query + description: > + (Optional.) Determines how the user should be logged out. When `global` is used, the user is logged out from all active sessions. When `local` is used, the user is logged out from the current session. When `others` is used, the user is logged out from all other sessions except the current one. Clients should remove stored access and refresh tokens except when `others` is used. + schema: + type: string + enum: + - global + - local + - others + responses: + 204: + description: No content returned on successful logout. + 401: + $ref: "#/components/responses/UnauthorizedResponse" + + /verify: + get: + summary: Authenticate by verifying the possession of a one-time token. Usually for use as clickable links. + tags: + - auth + parameters: + - name: token + in: query + required: true + schema: + type: string + - name: type + in: query + required: true + schema: + type: string + enum: + - signup + - invite + - recovery + - magiclink + - email_change + - name: redirect_to + in: query + description: > + (Optional) URL to redirect back into the app on after verification completes successfully. If not specified will use the "Site URL" configuration option. If not allowed per the allow list it will use the "Site URL" configuration option. + schema: + type: string + format: uri + security: + - APIKeyAuth: [] + responses: + 302: + $ref: "#/components/responses/AccessRefreshTokenRedirectResponse" + post: + summary: Authenticate by verifying the possession of a one-time token. + tags: + - auth + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + schema: + type: object + properties: + type: + type: string + enum: + - signup + - recovery + - invite + - magiclink + - email_change + - sms + - phone_change + token: + type: string + token_hash: + type: string + description: > + The hashed value of token. Applicable only if used with `type` and nothing else. + email: + type: string + format: email + description: > + Applicable only if `type` is with regards to an email address. + phone: + type: string + format: phone + description: > + Applicable only if `type` is with regards to an phone number. + redirect_to: + type: string + format: uri + description: > + (Optional) URL to redirect back into the app on after verification completes successfully. If not specified will use the "Site URL" configuration option. If not allowed per the allow list it will use the "Site URL" configuration option. + + responses: + 200: + description: An access and refresh token. 
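Another illustrative sketch under the same assumptions (placeholder AUTH_URL, placeholder anon key sent as an `apikey` header, my own function names): verifying an emailed signup OTP via POST /verify and then ending every session via POST /logout?scope=global.

const AUTH_URL = "http://localhost:9999";   // placeholder
const ANON_KEY = "<anon-api-key>";          // placeholder

// POST /verify with an emailed OTP; a 200 carries an access/refresh token pair.
export async function verifySignupOtp(email: string, token: string): Promise<unknown> {
  const res = await fetch(`${AUTH_URL}/verify`, {
    method: "POST",
    headers: { apikey: ANON_KEY, "Content-Type": "application/json" },
    body: JSON.stringify({ type: "signup", email, token }),
  });
  if (!res.ok) throw new Error(`verify failed: ${res.status}`);
  return res.json(); // AccessTokenResponseSchema shape
}

// POST /logout?scope=global; the spec promises 204 on success.
export async function signOutEverywhere(accessToken: string): Promise<void> {
  const res = await fetch(`${AUTH_URL}/logout?scope=global`, {
    method: "POST",
    headers: { apikey: ANON_KEY, Authorization: `Bearer ${accessToken}` },
  });
  if (res.status !== 204) throw new Error(`logout failed: ${res.status}`);
}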
+ content: + application/json: + schema: + $ref: "#/components/schemas/AccessTokenResponseSchema" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /authorize: + get: + summary: Redirects to an external OAuth provider. Usually for use as clickable links. + tags: + - oauth + security: + - APIKeyAuth: [] + parameters: + - name: provider + in: query + description: Name of the OAuth provider. + example: google + required: true + schema: + type: string + pattern: "^[a-zA-Z0-9]+$" + - name: scopes + in: query + required: true + description: Space separated list of OAuth scopes to pass on to `provider`. + schema: + type: string + pattern: "[^ ]+( +[^ ]+)*" + - name: invite_token + in: query + description: (Optional) A token representing a previous invitation of the user. A successful sign-in with OAuth will mark the invitation as completed. + schema: + type: string + - name: redirect_to + in: query + description: > + (Optional) URL to redirect back into the app on after OAuth sign-in completes successfully or not. If not specified will use the "Site URL" configuration option. If not allowed per the allow list it will use the "Site URL" configuration option. + schema: + type: string + format: uri + - name: code_challenge_method + in: query + description: (Optional) Method used to encrypt the verifier. Can be `plain` (no transformation) or `s256` (where SHA-256 is used). It is always recommended that `s256` is used. + schema: + type: string + enum: + - plain + - s256 + responses: + 302: + $ref: "#/components/responses/OAuthAuthorizeRedirectResponse" + + /signup: + post: + summary: Signs a user up. + description: > + Creates a new user. + tags: + - auth + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + examples: + "email+password": + value: + email: user@example.com + password: password1 + "phone+password": + value: + phone: "+1234567890" + password: password1 + "phone+password+whatsapp": + value: + phone: "+1234567890" + password: password1 + channel: whatsapp + "email+password+pkce": + value: + email: user@example.com + password: password1 + code_challenge_method: s256 + code_challenge: elU6u5zyqQT2f92GRQUq6PautAeNDf4DQPayyR0ek_c& + schema: + type: object + properties: + email: + type: string + format: email + phone: + type: string + format: phone + channel: + type: string + enum: + - sms + - whatsapp + password: + type: string + data: + type: object + code_challenge: + type: string + code_challenge_method: + type: string + enum: + - plain + - s256 + gotrue_meta_security: + $ref: "#/components/schemas/GoTrueMetaSecurity" + responses: + 200: + description: > + A user already exists and is not confirmed (in which case a user object is returned). A user did not exist and is signed up. If email or phone confirmation is enabled, returns a user object. If confirmation is disabled, returns an access token and refresh token response. + content: + application/json: + schema: + oneOf: + - $ref: "#/components/schemas/AccessTokenResponseSchema" + - $ref: "#/components/schemas/UserSchema" + 400: + $ref: "#/components/responses/BadRequestResponse" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /recover: + post: + summary: Request password recovery. + description: > + Users that have forgotten their password can have it reset with this API. 
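Sketch of the email-and-password sign-up described above; the URL, key and function name are placeholders of mine rather than anything defined in this spec.

const AUTH_URL = "http://localhost:9999";   // placeholder
const ANON_KEY = "<anon-api-key>";          // placeholder

// POST /signup; depending on confirmation settings the 200 body is either a
// user object or an access/refresh token response (the oneOf shown above).
export async function signUpWithEmail(email: string, password: string, metadata: Record<string, unknown> = {}) {
  const res = await fetch(`${AUTH_URL}/signup`, {
    method: "POST",
    headers: { apikey: ANON_KEY, "Content-Type": "application/json" },
    body: JSON.stringify({ email, password, data: metadata }),
  });
  if (res.status === 429) throw new Error("rate limited, retry later");
  if (!res.ok) throw new Error(`signup failed: ${res.status}`);
  return res.json();
}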
+ tags: + - auth + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + schema: + type: object + required: + - email + properties: + email: + type: string + format: email + code_challenge: + type: string + code_challenge_method: + type: string + enum: + - plain + - s256 + gotrue_meta_security: + $ref: "#/components/schemas/GoTrueMetaSecurity" + responses: + 200: + description: A recovery email has been sent to the address. An empty JSON object is returned. To obfuscate whether such an email address already exists in the system this response is sent regardless whether the address exists or not. + content: + application/json: + schema: + type: object + 400: + $ref: "#/components/responses/BadRequestResponse" + 422: + description: Returned when unable to validate the email address. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /resend: + post: + summary: Resends a one-time password (OTP) through email or SMS. + description: > + Allows a user to resend an existing signup, sms, email_change or phone_change OTP. + tags: + - auth + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + schema: + type: object + properties: + email: + type: string + format: email + description: > + Applicable only if `type` is with regards to an email address. + phone: + type: string + format: phone + description: > + Applicable only if `type` is with regards to an phone number. + type: + type: string + enum: + - signup + - email_change + - sms + - phone_change + gotrue_meta_security: + $ref: "#/components/schemas/GoTrueMetaSecurity" + responses: + 200: + description: A One-Time Password was sent to the email or phone. To obfuscate whether such an address or number already exists in the system this response is sent in both cases. + content: + application/json: + schema: + type: object + properties: + message_id: + type: string + description: Unique ID of the message as reported by the SMS sending provider. Useful for tracking deliverability problems. + 400: + $ref: "#/components/responses/BadRequestResponse" + 422: + description: Returned when unable to validate the email address or phone number. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /magiclink: + post: + summary: Authenticate a user by sending them a magic link. + description: > + A magic link is a special type of URL that includes a One-Time Password. When a user visits this link in a browser they are immediately authenticated. + tags: + - auth + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + schema: + type: object + required: + - email + properties: + email: + type: string + format: email + data: + type: object + gotrue_meta_security: + $ref: "#/components/schemas/GoTrueMetaSecurity" + responses: + 200: + description: A recovery email has been sent to the address. An empty JSON object is returned. To obfuscate whether such an email address already exists in the system this response is sent regardless whether the address exists or not. + content: + application/json: + schema: + type: object + 400: + $ref: "#/components/responses/BadRequestResponse" + 422: + description: Returned when unable to validate the email address. 
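A short sketch of the recovery and magic-link requests above, again with a placeholder base URL and anon key; both endpoints deliberately return an empty object whether or not the address exists.

const AUTH_URL = "http://localhost:9999";   // placeholder
const ANON_KEY = "<anon-api-key>";          // placeholder

// POST /recover: sends a password-recovery email; the 200 body is an empty object.
export async function requestPasswordRecovery(email: string): Promise<void> {
  const res = await fetch(`${AUTH_URL}/recover`, {
    method: "POST",
    headers: { apikey: ANON_KEY, "Content-Type": "application/json" },
    body: JSON.stringify({ email }),
  });
  if (!res.ok) throw new Error(`recover failed: ${res.status}`);
}

// POST /magiclink: emails a one-time sign-in link.
export async function sendMagicLink(email: string): Promise<void> {
  const res = await fetch(`${AUTH_URL}/magiclink`, {
    method: "POST",
    headers: { apikey: ANON_KEY, "Content-Type": "application/json" },
    body: JSON.stringify({ email }),
  });
  if (!res.ok) throw new Error(`magiclink failed: ${res.status}`);
}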
+ content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /otp: + post: + summary: Authenticate a user by sending them a One-Time Password over email or SMS. + tags: + - auth + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + schema: + type: object + properties: + email: + type: string + format: email + phone: + type: string + format: phone + channel: + type: string + enum: + - sms + - whatsapp + create_user: + type: boolean + data: + type: object + code_challenge_method: + type: string + enum: + - s256 + - plain + code_challenge: + type: string + gotrue_meta_security: + $ref: "#/components/schemas/GoTrueMetaSecurity" + responses: + 200: + description: A One-Time Password was sent to the email or phone. To obfuscate whether such an address or number already exists in the system this response is sent in both cases. + content: + application/json: + schema: + type: object + properties: + message_id: + type: string + description: Unique ID of the message as reported by the SMS sending provider. Useful for tracking deliverability problems. + 400: + $ref: "#/components/responses/BadRequestResponse" + 422: + description: Returned when unable to validate the email or phone number. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /user: + get: + summary: Fetch the latest user account information. + tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + responses: + 200: + description: User's account information. + content: + application/json: + schema: + $ref: "#/components/schemas/UserSchema" + put: + summary: Update certain properties of the current user account. + tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + requestBody: + content: + application/json: + schema: + type: object + properties: + email: + type: string + format: email + phone: + type: string + format: phone + password: + type: string + nonce: + type: string + data: + type: object + app_metadata: + type: object + channel: + type: string + enum: + - sms + - whatsapp + responses: + 200: + description: User's updated account information. + content: + application/json: + schema: + $ref: "#/components/schemas/UserSchema" + 400: + $ref: "#/components/responses/BadRequestResponse" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /reauthenticate: + post: + summary: Reauthenticates the possession of an email or phone number for the purpose of password change. + description: > + For a password to be changed on a user account, the user's email or phone number needs to be confirmed before they are allowed to set a new password. This requirement is configurable. This API sends a confirmation email or SMS message. A nonce in this message can be provided in `PUT /user` to change the password on the account. + tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + responses: + 200: + description: A One-Time Password was sent to the user's email or phone. + content: + application/json: + schema: + type: object + 400: + $ref: "#/components/responses/BadRequestResponse" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /factors: + post: + summary: Begin enrolling a new factor for MFA. 
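Sketch of the /user and /reauthenticate flow described above: fetch the account, request a reauthentication nonce, then change the password with PUT /user. URL, key and helper names are assumed, not taken from the spec.

const AUTH_URL = "http://localhost:9999";   // placeholder
const ANON_KEY = "<anon-api-key>";          // placeholder

// GET /user returns the current account (UserSchema).
export async function getUser(accessToken: string) {
  const res = await fetch(`${AUTH_URL}/user`, {
    headers: { apikey: ANON_KEY, Authorization: `Bearer ${accessToken}` },
  });
  if (!res.ok) throw new Error(`get user failed: ${res.status}`);
  return res.json();
}

// POST /reauthenticate sends the confirmation message whose nonce authorizes a password change.
export async function requestReauthentication(accessToken: string): Promise<void> {
  const res = await fetch(`${AUTH_URL}/reauthenticate`, {
    method: "POST",
    headers: { apikey: ANON_KEY, Authorization: `Bearer ${accessToken}` },
  });
  if (!res.ok) throw new Error(`reauthenticate failed: ${res.status}`);
}

// PUT /user with { password, nonce } sets the new password.
export async function changePassword(accessToken: string, newPassword: string, nonce: string) {
  const res = await fetch(`${AUTH_URL}/user`, {
    method: "PUT",
    headers: { apikey: ANON_KEY, Authorization: `Bearer ${accessToken}`, "Content-Type": "application/json" },
    body: JSON.stringify({ password: newPassword, nonce }),
  });
  if (!res.ok) throw new Error(`password update failed: ${res.status}`);
  return res.json();
}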
+ tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + requestBody: + content: + application/json: + schema: + type: object + required: + - factor_type + properties: + factor_type: + type: string + enum: + - totp + - phone + - webauthn + friendly_name: + type: string + issuer: + type: string + format: uri + phone: + type: string + format: phone + responses: + 200: + description: > + A new factor was created in the unverified state. Call `POST /factors/{factorId}/verify' to verify it. + content: + application/json: + schema: + type: object + properties: + id: + type: string + type: + type: string + enum: + - totp + - phone + - webauthn + totp: + type: object + properties: + qr_code: + type: string + secret: + type: string + uri: + type: string + phone: + type: string + format: phone + + 400: + $ref: "#/components/responses/BadRequestResponse" + + /factors/{factorId}/challenge: + post: + summary: Create a new challenge for a MFA factor. + tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + parameters: + - name: factorId + in: path + required: true + example: 2b306a77-21dc-4110-ba71-537cb56b9e98 + schema: + type: string + format: uuid + requestBody: + content: + application/json: + schema: + type: object + properties: + channel: + type: string + enum: + - sms + - whatsapp + + responses: + 200: + description: > + A new challenge was generated for the factor. Use `POST /factors/{factorId}/verify` to verify the challenge. + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/TOTPPhoneChallengeResponse' + - $ref: '#/components/schemas/WebAuthnChallengeResponse' + 400: + $ref: "#/components/responses/BadRequestResponse" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /factors/{factorId}/verify: + post: + summary: Verify a challenge on a factor. + tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + parameters: + - name: factorId + in: path + required: true + example: 2b306a77-21dc-4110-ba71-537cb56b9e98 + schema: + type: string + format: uuid + requestBody: + content: + application/json: + schema: + type: object + required: + - challenge_id + properties: + challenge_id: + type: string + format: uuid + code: + type: string + responses: + 200: + description: > + This challenge has been verified. Client libraries should replace their stored access and refresh tokens with the ones provided in this response. These new credentials have an increased Authenticator Assurance Level (AAL). + content: + application/json: + schema: + $ref: "#/components/schemas/AccessTokenResponseSchema" + 400: + $ref: "#/components/responses/BadRequestResponse" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /factors/{factorId}: + delete: + summary: Remove a MFA factor from a user. + tags: + - user + security: + - APIKeyAuth: [] + UserAuth: [] + parameters: + - name: factorId + in: path + required: true + example: 2b306a77-21dc-4110-ba71-537cb56b9e98 + schema: + type: string + format: uuid + responses: + 200: + description: > + This MFA factor is removed (unenrolled) and cannot be used for increasing the AAL level of user's sessions. Client libraries should use the `POST /token?grant_type=refresh_token` endpoint to get a new access and refresh token with a decreased AAL. 
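Sketch of the TOTP enroll / challenge / verify sequence covered by the /factors endpoints above. Besides the usual placeholder URL and key, it assumes the TOTP challenge response (TOTPPhoneChallengeResponse, whose schema is not shown in this excerpt) carries an `id` field to pass back as `challenge_id`.

const AUTH_URL = "http://localhost:9999";   // placeholder
const ANON_KEY = "<anon-api-key>";          // placeholder

// POST /factors: returns the unverified factor plus totp.qr_code / totp.secret / totp.uri.
export async function enrollTotpFactor(accessToken: string, friendlyName: string) {
  const res = await fetch(`${AUTH_URL}/factors`, {
    method: "POST",
    headers: { apikey: ANON_KEY, Authorization: `Bearer ${accessToken}`, "Content-Type": "application/json" },
    body: JSON.stringify({ factor_type: "totp", friendly_name: friendlyName }),
  });
  if (!res.ok) throw new Error(`enroll failed: ${res.status}`);
  return res.json();
}

// POST /factors/{id}/challenge then POST /factors/{id}/verify with the TOTP code;
// the 200 body is a fresh token pair at a higher AAL.
export async function verifyTotpFactor(accessToken: string, factorId: string, code: string) {
  const headers = { apikey: ANON_KEY, Authorization: `Bearer ${accessToken}`, "Content-Type": "application/json" };
  const challengeRes = await fetch(`${AUTH_URL}/factors/${factorId}/challenge`, { method: "POST", headers });
  if (!challengeRes.ok) throw new Error(`challenge failed: ${challengeRes.status}`);
  const challenge = (await challengeRes.json()) as { id: string }; // assumption: challenge carries an id
  const verifyRes = await fetch(`${AUTH_URL}/factors/${factorId}/verify`, {
    method: "POST",
    headers,
    body: JSON.stringify({ challenge_id: challenge.id, code }),
  });
  if (!verifyRes.ok) throw new Error(`verify failed: ${verifyRes.status}`);
  return verifyRes.json();
}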
+ content: + application/json: + schema: + type: object + properties: + id: + type: string + format: uuid + example: 2b306a77-21dc-4110-ba71-537cb56b9e98 + 400: + $ref: "#/components/responses/BadRequestResponse" + + /callback: + get: + summary: Redirects OAuth flow errors to the frontend app. + description: > + When an OAuth sign-in flow fails for any reason, the error message needs to be delivered to the frontend app requesting the flow. This callback delivers the errors as `error` and `error_description` query params. Usually this request is not called directly. + tags: + - oauth + security: + - APIKeyAuth: [] + responses: + 302: + $ref: "#/components/responses/OAuthCallbackRedirectResponse" + post: + summary: Redirects OAuth flow errors to the frontend app. + description: > + When an OAuth sign-in flow fails for any reason, the error message needs to be delivered to the frontend app requesting the flow. This callback delivers the errors as `error` and `error_description` query params. Usually this request is not called directly. + tags: + - oauth + responses: + 302: + $ref: "#/components/responses/OAuthCallbackRedirectResponse" + + /sso: + post: + summary: Initiate a Single-Sign On flow. + tags: + - sso + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + schema: + type: object + properties: + domain: + type: string + format: hostname + description: Email address domain used to identify the SSO provider. + provider_id: + type: string + format: uuid + example: 40451fc2-4997-429c-bf7f-cc6f33c788e6 + redirect_to: + type: string + format: uri + skip_http_redirect: + type: boolean + description: Set to `true` if the response to this request should not be a HTTP 303 redirect -- useful for browser-based applications. + code_challenge: + type: string + code_challenge_method: + type: string + enum: + - plain + - s256 + gotrue_meta_security: + $ref: "#/components/schemas/GoTrueMetaSecurity" + responses: + 200: + description: > + Returned only when `skip_http_redirect` is `true` and the SSO provider could be identified from the `provider_id` or `domain`. Client libraries should use the returned URL to redirect or open a browser. + content: + application/json: + schema: + type: object + properties: + url: + type: string + format: uri + 303: + description: > + Returned only when `skip_http_redirect` is `false` or not present and the SSO provider could be identified from the `provider_id` or `domain`. Client libraries should follow the redirect. 303 is used instead of 302 because the request should be executed with a `GET` verb. + headers: + Location: + schema: + type: string + format: uri + 400: + $ref: "#/components/responses/BadRequestResponse" + 404: + description: > + Returned when the SSO provider could not be identified. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + /saml/metadata: + get: + summary: Returns the SAML 2.0 Metadata XML. + description: > + The metadata XML can be downloaded or used for the SAML 2.0 Metadata URL discovery mechanism. This URL is the SAML 2.0 EntityID of the Service Provider implemented by this server. + tags: + - saml + security: + - APIKeyAuth: [] + parameters: + - name: download + in: query + description: > + If set to `true` will add a `Content-Disposition` header to the response which will trigger a download dialog on the browser. + schema: + type: boolean + responses: + 200: + description: > + A valid SAML 2.0 Metadata XML document. 
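Sketch of a Service-Provider-initiated SSO sign-in using POST /sso with `skip_http_redirect: true`, as described above, so the returned URL can be handed to the browser instead of following a 303. Base URL, key and function name remain placeholders.

const AUTH_URL = "http://localhost:9999";   // placeholder
const ANON_KEY = "<anon-api-key>";          // placeholder

// POST /sso: identify the provider by email domain and get back the redirect URL.
export async function startSsoSignIn(domain: string): Promise<string> {
  const res = await fetch(`${AUTH_URL}/sso`, {
    method: "POST",
    headers: { apikey: ANON_KEY, "Content-Type": "application/json" },
    body: JSON.stringify({ domain, skip_http_redirect: true }),
  });
  if (res.status === 404) throw new Error(`no SSO provider registered for ${domain}`);
  if (!res.ok) throw new Error(`sso request failed: ${res.status}`);
  const { url } = (await res.json()) as { url: string };
  return url; // e.g. window.location.href = url in a browser app
}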
Should be cached according to the `Cache-Control` header and/or caching data specified in the document itself. + headers: + Content-Disposition: + description: > + Present if `download=true`, which triggers the browser to show a donwload dialog. + schema: + type: string + example: attachment; filename="metadata.xml" + Cache-Control: + description: > + Should be parsed and obeyed to avoid putting strain on the server. + schema: + type: string + example: public, max-age=600 + + /saml/acs: + post: + summary: SAML 2.0 Assertion Consumer Service (ACS) endpoint. + description: > + Implements the SAML 2.0 Assertion Consumer Service (ACS) endpoint supporting the POST and Artifact bindings. + tags: + - saml + security: [] + parameters: + - name: RelayState + in: query + schema: + oneOf: + - type: string + format: uri + description: URL to take the user to after the ACS has been verified. Often sent by Identity Provider initiated login requests. + - type: string + format: uuid + description: UUID of the SAML Relay State stored in the database, used to identify the Service Provider initiated login request. + - name: SAMLArt + in: query + description: > + See the SAML 2.0 ACS specification. Cannot be used without a UUID `RelayState` parameter. + schema: + type: string + - name: SAMLResponse + in: query + description: > + See the SAML 2.0 ACS specification. Must be present unless `SAMLArt` is specified. If `RelayState` is not a UUID, the SAML Response is unpacked and the identity provider is identified from the response. + schema: + type: string + responses: + 302: + $ref: "#/components/responses/AccessRefreshTokenRedirectResponse" + 400: + $ref: "#/components/responses/BadRequestResponse" + 429: + $ref: "#/components/responses/RateLimitResponse" + + /invite: + post: + summary: Invite a user by email. + description: > + Sends an invitation email which contains a link that allows the user to sign-in. + tags: + - admin + security: + - APIKeyAuth: [] + requestBody: + content: + application/json: + schema: + type: object + required: + - email + properties: + email: + type: string + data: + type: object + responses: + 200: + description: An invitation has been sent to the user. + content: + application/json: + schema: + $ref: "#/components/schemas/UserSchema" + 400: + $ref: "#/components/responses/BadRequestResponse" + 422: + description: User already exists and has confirmed their address. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + /admin/generate_link: + post: + summary: Generate a link to send in an email message. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + requestBody: + content: + application/json: + schema: + type: object + required: + - type + - email + properties: + type: + type: string + enum: + - magiclink + - signup + - recovery + - email_change_current + - email_change_new + email: + type: string + format: email + new_email: + type: string + format: email + password: + type: string + data: + type: object + redirect_to: + type: string + format: uri + responses: + 200: + description: User profile and generated link information. 
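Sketch of the admin invite endpoint above. The excerpt only lists APIKeyAuth for /invite, so sending a service-role key both as `apikey` and as a bearer token is an assumption about how elevated access works in this stack, not something stated here.

const AUTH_URL = "http://localhost:9999";          // placeholder
const SERVICE_ROLE_KEY = "<service-role-key>";     // assumption: elevated key for admin APIs

// POST /invite: emails an invitation link; 200 returns the invited user's UserSchema.
export async function inviteUserByEmail(email: string) {
  const res = await fetch(`${AUTH_URL}/invite`, {
    method: "POST",
    headers: { apikey: SERVICE_ROLE_KEY, Authorization: `Bearer ${SERVICE_ROLE_KEY}`, "Content-Type": "application/json" },
    body: JSON.stringify({ email }),
  });
  if (res.status === 422) throw new Error("user already exists and is confirmed");
  if (!res.ok) throw new Error(`invite failed: ${res.status}`);
  return res.json();
}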
+ content: + application/json: + schema: + type: object + additionalProperties: true + properties: + action_link: + type: string + format: uri + email_otp: + type: string + hashed_token: + type: string + verification_type: + type: string + redirect_to: + type: string + format: uri + 400: + $ref: "#/components/responses/BadRequestResponse" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: There is no such user. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + 422: + description: > + Has multiple meanings: + - User already exists + - Provided password does not meet minimum criteria + - Secure email change not enabled + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + /admin/audit: + get: + summary: Fetch audit log events. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + parameters: + - name: page + in: query + schema: + type: integer + minimum: 1 + default: 1 + - name: per_page + in: query + schema: + type: integer + minimum: 1 + default: 50 + responses: + 200: + description: List of audit logs. + content: + application/json: + schema: + type: array + items: + type: object + properties: + id: + type: string + format: uuid + payload: + type: object + properties: + actor_id: + type: string + actor_via_sso: + type: boolean + description: Whether the actor used a SSO protocol (like SAML 2.0 or OIDC) to authenticate. + actor_username: + type: string + actor_name: + type: string + traits: + type: object + action: + type: string + description: |- + Usually one of these values: + - login + - logout + - invite_accepted + - user_signedup + - user_invited + - user_deleted + - user_modified + - user_recovery_requested + - user_reauthenticate_requested + - user_confirmation_requested + - user_repeated_signup + - user_updated_password + - token_revoked + - token_refreshed + - generate_recovery_codes + - factor_in_progress + - factor_unenrolled + - challenge_created + - verification_attempted + - factor_deleted + - recovery_codes_deleted + - factor_updated + - mfa_code_login + log_type: + type: string + description: |- + Usually one of these values: + - account + - team + - token + - user + - factor + - recovery_codes + created_at: + type: string + format: date-time + ip_address: + type: string + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + + /admin/users: + get: + summary: Fetch a listing of users. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + parameters: + - name: page + in: query + schema: + type: integer + minimum: 1 + default: 1 + - name: per_page + in: query + schema: + type: integer + minimum: 1 + default: 50 + responses: + 200: + description: A page of users. + content: + application/json: + schema: + type: object + properties: + aud: + type: string + deprecated: true + users: + type: array + items: + $ref: "#/components/schemas/UserSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + + /admin/users/{userId}: + parameters: + - name: userId + in: path + required: true + schema: + type: string + format: uuid + get: + summary: Fetch user account data for a user. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 200: + description: User's account data. 
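Sketch of two of the admin endpoints above: generating a magic link with POST /admin/generate_link and paging through GET /admin/users. The service-role bearer convention and the helper names are, again, assumptions.

const AUTH_URL = "http://localhost:9999";          // placeholder
const SERVICE_ROLE_KEY = "<service-role-key>";     // assumption
const adminHeaders = { apikey: SERVICE_ROLE_KEY, Authorization: `Bearer ${SERVICE_ROLE_KEY}`, "Content-Type": "application/json" };

// POST /admin/generate_link: the 200 body includes action_link, hashed_token, etc.
export async function generateMagicLink(email: string, redirectTo?: string) {
  const res = await fetch(`${AUTH_URL}/admin/generate_link`, {
    method: "POST",
    headers: adminHeaders,
    body: JSON.stringify({ type: "magiclink", email, redirect_to: redirectTo }),
  });
  if (!res.ok) throw new Error(`generate_link failed: ${res.status}`);
  return res.json() as Promise<{ action_link: string; hashed_token: string }>;
}

// GET /admin/users?page=&per_page=: the 200 body wraps the page in a `users` array.
export async function listUsers(page = 1, perPage = 50) {
  const res = await fetch(`${AUTH_URL}/admin/users?page=${page}&per_page=${perPage}`, { headers: adminHeaders });
  if (!res.ok) throw new Error(`list users failed: ${res.status}`);
  const body = (await res.json()) as { users: unknown[] };
  return body.users;
}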
+ content: + application/json: + schema: + $ref: "#/components/schemas/UserSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: There is no such user. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + put: + summary: Update user's account data. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + requestBody: + content: + application/json: + schema: + $ref: "#/components/schemas/UserSchema" + responses: + 200: + description: User's account data was updated. + content: + application/json: + schema: + $ref: "#/components/schemas/UserSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: There is no such user. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + delete: + summary: Delete a user. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 200: + description: User's account data. + content: + application/json: + schema: + $ref: "#/components/schemas/UserSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: There is no such user. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + /admin/users/{userId}/factors: + parameters: + - name: userId + in: path + required: true + schema: + type: string + format: uuid + get: + summary: List all of the MFA factors for a user. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 200: + description: User's MFA factors. + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/MFAFactorSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: There is no such user. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + /admin/users/{userId}/factors/{factorId}: + parameters: + - name: userId + in: path + required: true + schema: + type: string + format: uuid + - name: factorId + in: path + required: true + schema: + type: string + format: uuid + put: + summary: Update a user's MFA factor. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + requestBody: + content: + application/json: + schema: + type: object + responses: + 200: + description: User's MFA factor. + content: + application/json: + schema: + $ref: "#/components/schemas/MFAFactorSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: There is no such user and/or factor. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + delete: + summary: Remove a user's MFA factor. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 200: + description: User's MFA factor. + content: + application/json: + schema: + $ref: "#/components/schemas/MFAFactorSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: There is no such user and/or factor. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + /admin/sso/providers: + get: + summary: Fetch a list of all registered SSO providers. 
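For the per-user factor endpoints above, a hedged sketch of a small cleanup job that lists a user's MFA factors and removes the ones that never completed verification (the `status` values are described in `MFAFactorSchema` later in this file; `authUrl`/`serviceRoleKey` remain placeholders):

```ts
// Hedged sketch: delete a user's unverified MFA factors via the admin API.
export async function pruneUnverifiedFactors(
  authUrl: string,
  serviceRoleKey: string,
  userId: string
) {
  const headers = {
    apikey: serviceRoleKey,
    Authorization: `Bearer ${serviceRoleKey}`
  }
  const res = await fetch(`${authUrl}/admin/users/${userId}/factors`, { headers })
  if (!res.ok) throw new Error(`listing factors failed with status ${res.status}`)
  const factors = (await res.json()) as { id: string; status: string }[]
  for (const factor of factors.filter(f => f.status === "unverified")) {
    await fetch(`${authUrl}/admin/users/${userId}/factors/${factor.id}`, {
      method: "DELETE",
      headers
    })
  }
}
```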
+ tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 200: + description: A list of all providers. + content: + application/json: + schema: + type: object + properties: + items: + type: array + items: + $ref: "#/components/schemas/SSOProviderSchema" + post: + summary: Register a new SSO provider. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + requestBody: + content: + application/json: + schema: + type: object + required: + - type + properties: + type: + type: string + enum: + - saml + metadata_url: + type: string + format: uri + metadata_xml: + type: string + domains: + type: array + items: + type: string + format: hostname + attribute_mapping: + $ref: "#/components/schemas/SAMLAttributeMappingSchema" + responses: + 200: + description: SSO provider was created. + content: + application/json: + schema: + $ref: "#/components/schemas/SSOProviderSchema" + 400: + $ref: "#/components/responses/BadRequestResponse" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + + /admin/sso/providers/{ssoProviderId}: + parameters: + - name: ssoProviderId + in: path + required: true + schema: + type: string + format: uuid + get: + summary: Fetch SSO provider details. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 200: + description: SSO provider exists with these details. + content: + application/json: + schema: + $ref: "#/components/schemas/SSOProviderSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: A provider with this UUID does not exist. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + put: + summary: Update details about an SSO provider. + description: > + You can update only one of `metadata_url` or `metadata_xml` at a time. The SAML Metadata represented by these updates must advertise the same Identity Provider EntityID. Omit the `domains` or `attribute_mapping` property to keep the existing database values. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + requestBody: + content: + application/json: + schema: + type: object + properties: + metadata_url: + type: string + format: uri + metadata_xml: + type: string + domains: + type: array + items: + type: string + pattern: "[a-z0-9-]+([.][a-z0-9-]+)*" + attribute_mapping: + $ref: "#/components/schemas/SAMLAttributeMappingSchema" + responses: + 200: + description: SSO provider details were updated. + content: + application/json: + schema: + $ref: "#/components/schemas/SSOProviderSchema" + 400: + $ref: "#/components/responses/BadRequestResponse" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: A provider with this UUID does not exist. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + delete: + summary: Remove an SSO provider. + tags: + - admin + security: + - APIKeyAuth: [] + AdminAuth: [] + responses: + 200: + description: SSO provider was removed. + content: + application/json: + schema: + $ref: "#/components/schemas/SSOProviderSchema" + 401: + $ref: "#/components/responses/UnauthorizedResponse" + 403: + $ref: "#/components/responses/ForbiddenResponse" + 404: + description: A provider with this UUID does not exist.
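To make the SSO provider registration above concrete, a hedged sketch of `POST /admin/sso/providers` with an illustrative metadata URL and domain (not real values; `authUrl`/`serviceRoleKey` are placeholders as before):

```ts
// Hedged sketch: register a SAML identity provider by metadata URL.
export async function addSamlProvider(authUrl: string, serviceRoleKey: string) {
  const res = await fetch(`${authUrl}/admin/sso/providers`, {
    method: "POST",
    headers: {
      apikey: serviceRoleKey,
      Authorization: `Bearer ${serviceRoleKey}`,
      "Content-Type": "application/json"
    },
    body: JSON.stringify({
      type: "saml",
      metadata_url: "https://idp.example.com/metadata.xml", // illustrative only
      domains: ["example.com"]
    })
  })
  if (!res.ok) throw new Error(`provider registration failed with status ${res.status}`)
  return res.json() // SSOProviderSchema
}
```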
+ content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + /health: + get: + summary: Service healthcheck. + description: Ping this endpoint to receive information about the health of the service. + tags: + - general + security: + - APIKeyAuth: [] + responses: + 200: + description: > + Service is healthy. + content: + application/json: + schema: + type: object + properties: + version: + type: string + example: v2.40.1 + name: + type: string + example: GoTrue + description: + type: string + example: GoTrue is a user registration and authentication API + + 500: + description: > + Service is not healthy. Retriable with exponential backoff. + 502: + description: > + Service is not healthy: infrastructure issue. Usually not retriable. + 503: + description: > + Service is not healthy: infrastructure issue. Retriable with exponential backoff. + 504: + description: > + Service is not healthy: request timed out. Retriable with exponential backoff. + + /settings: + get: + summary: Retrieve some of the public settings of the server. + description: > + Use this endpoint to configure parts of any authentication UIs depending on the configured settings. + tags: + - general + security: + - APIKeyAuth: [] + responses: + 200: + description: > + Currently applicable settings of the server. + content: + application/json: + schema: + type: object + properties: + disable_signup: + type: boolean + example: false + description: Whether new accounts can be created. (Valid for all providers.) + mailer_autoconfirm: + type: boolean + example: false + description: Whether new email addresses need to be confirmed before sign-in is possible. + phone_autoconfirm: + type: boolean + example: false + description: Whether new phone numbers need to be confirmed before sign-in is possible. + sms_provider: + type: string + optional: true + example: twilio + description: Which SMS provider is being used to send messages to phone numbers. + saml_enabled: + type: boolean + example: true + description: Whether SAML is enabled on this API server. Defaults to false. + external: + type: object + description: Which external identity providers are enabled. + example: + github: true + apple: true + email: true + phone: true + patternProperties: + "[a-zA-Z0-9]+": + type: boolean + +components: + securitySchemes: + UserAuth: + type: http + scheme: bearer + description: > + An access token in the form of a JWT issued by this server. + + AdminAuth: + type: http + scheme: bearer + description: > + A special admin JWT. + + APIKeyAuth: + type: apiKey + in: header + name: apikey + description: > + When deployed on Supabase, this server requires an `apikey` header containing a valid Supabase-issued API key to call any endpoint. + + schemas: + GoTrueMetaSecurity: + type: object + description: > + Use this property to pass a CAPTCHA token only if you have enabled CAPTCHA protection. + properties: + captcha_token: + type: string + + ErrorSchema: + type: object + properties: + error: + type: string + description: |- + Certain responses will contain this property with the provided values. + + Usually one of these: + - invalid_request + - unauthorized_client + - access_denied + - server_error + - temporarily_unavailable + - unsupported_otp_type + error_description: + type: string + description: > + Certain responses that have an `error` property may have this property which describes the error. + code: + type: integer + description: > + The HTTP status code. Usually missing if `error` is present.
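Since the `/health` responses above explicitly mark 500/503/504 as retriable with exponential backoff and 502 as usually not retriable, here is a hedged sketch of a startup probe that follows that guidance (`authUrl`/`apikey` are placeholders, not defined in this spec):

```ts
// Hedged sketch: poll GET /health with exponential backoff until healthy.
export async function waitForHealthy(authUrl: string, apikey: string, attempts = 5) {
  for (let attempt = 0; attempt < attempts; attempt++) {
    const res = await fetch(`${authUrl}/health`, { headers: { apikey } })
    if (res.ok) return res.json() // { version, name, description }
    if (res.status === 502) break // documented above as usually not retriable
    await new Promise(r => setTimeout(r, 2 ** attempt * 1000)) // 1s, 2s, 4s, ...
  }
  throw new Error("auth service did not become healthy in time")
}
```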
+ example: 400 + msg: + type: string + description: > + A basic message describing the problem with the request. Usually missing if `error` is present. + weak_password: + type: object + description: > + Only returned on the `/signup` endpoint if the password used is too weak. Inspect the `reasons` and `msg` property to identify the causes. + properties: + reasons: + type: array + items: + type: string + enum: + - length + - characters + - pwned + + UserSchema: + type: object + description: Object describing the user related to the issued access and refresh tokens. + properties: + id: + type: string + format: uuid + aud: + type: string + deprecated: true + role: + type: string + email: + type: string + description: User's primary contact email. In most cases you can uniquely identify a user by their email address, but not in all cases. + email_confirmed_at: + type: string + format: date-time + phone: + type: string + format: phone + description: User's primary contact phone number. In most cases you can uniquely identify a user by their phone number, but not in all cases. + phone_confirmed_at: + type: string + format: date-time + confirmation_sent_at: + type: string + format: date-time + confirmed_at: + type: string + format: date-time + recovery_sent_at: + type: string + format: date-time + new_email: + type: string + format: email + email_change_sent_at: + type: string + format: date-time + new_phone: + type: string + format: phone + phone_change_sent_at: + type: string + format: date-time + reauthentication_sent_at: + type: string + format: date-time + last_sign_in_at: + type: string + format: date-time + app_metadata: + type: object + user_metadata: + type: object + factors: + type: array + items: + $ref: "#/components/schemas/MFAFactorSchema" + identities: + type: array + items: + $ref: "#/components/schemas/IdentitySchema" + banned_until: + type: string + format: date-time + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + deleted_at: + type: string + format: date-time + is_anonymous: + type: boolean + + SAMLAttributeMappingSchema: + type: object + properties: + keys: + type: object + patternProperties: + ".+": + type: object + properties: + name: + type: string + names: + type: array + items: + type: string + default: + oneOf: + - type: string + - type: number + - type: boolean + - type: object + + SSOProviderSchema: + type: object + properties: + id: + type: string + format: uuid + sso_domains: + type: array + items: + type: object + properties: + domain: + type: string + format: hostname + saml: + type: object + properties: + entity_id: + type: string + metadata_xml: + type: string + metadata_url: + type: string + attribute_mapping: + $ref: "#/components/schemas/SAMLAttributeMappingSchema" + + AccessTokenResponseSchema: + type: object + properties: + access_token: + type: string + description: A valid JWT that will expire in `expires_in` seconds. + refresh_token: + type: string + description: An opaque string that can be used once to obtain a new access and refresh token. + token_type: + type: string + description: What type of token this is. Only `bearer` returned, may change in the future. + expires_in: + type: integer + description: Number of seconds after which the `access_token` should be renewed by using the refresh token with the `refresh_token` grant type. + expires_at: + type: integer + description: UNIX timestamp after which the `access_token` should be renewed by using the refresh token with the `refresh_token` grant type. 
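The `expires_in`/`expires_at` descriptions above suggest refreshing proactively rather than waiting for a 401. A hedged sketch of scheduling a refresh one minute early; the actual refresh call (`/token?grant_type=refresh_token`) is assumed and not shown in this excerpt:

```ts
// Hedged sketch: schedule a proactive token refresh based on expires_at.
interface TokenResponse {
  access_token: string
  refresh_token: string
  expires_at: number // UNIX seconds, as documented above
}

export function scheduleRefresh(
  token: TokenResponse,
  refresh: (refreshToken: string) => Promise<TokenResponse>
) {
  const msUntilRefresh = Math.max(0, (token.expires_at - 60) * 1000 - Date.now())
  setTimeout(async () => {
    const next = await refresh(token.refresh_token)
    scheduleRefresh(next, refresh) // keep the session alive
  }, msUntilRefresh)
}
```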
+ weak_password: + type: object + description: Only returned on the `/token?grant_type=password` endpoint. When present, it indicates that the password used is weak. Inspect the `reasons` and/or `message` properties to identify why. + properties: + reasons: + type: array + items: + type: string + enum: + - length + - characters + - pwned + message: + type: string + user: + $ref: "#/components/schemas/UserSchema" + + MFAFactorSchema: + type: object + description: Represents a MFA factor. + properties: + id: + type: string + format: uuid + status: + type: string + description: |- + Usually one of: + - verified + - unverified + friendly_name: + type: string + factor_type: + type: string + description: |- + Usually one of: + - totp + - phone + - webauthn + web_authn_credential: + type: jsonb + phone: + type: string + format: phone + nullable: true + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + last_challenged_at: + type: string + format: date-time + nullable: true + + + IdentitySchema: + type: object + properties: + identity_id: + type: string + format: uuid + id: + type: string + format: uuid + user_id: + type: string + format: uuid + identity_data: + type: object + provider: + type: string + last_sign_in_at: + type: string + format: date-time + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + email: + type: string + format: email + TOTPPhoneChallengeResponse: + type: object + required: + - id + - type + - expires_at + properties: + id: + type: string + format: uuid + example: 14c1560e-2749-4522-bb62-d1458451830a + description: ID of the challenge. + type: + type: string + enum: [totp, phone] + description: Type of the challenge. + expires_at: + type: integer + example: 1674840917 + description: UNIX seconds of the timestamp past which the challenge should not be verified. + + WebAuthnChallengeResponse: + type: object + required: + - id + - type + - expires_at + - credential_options + properties: + id: + type: string + format: uuid + example: 14c1560e-2749-4522-bb62-d1458451830a + description: ID of the challenge. + type: + type: string + enum: [webauthn] + description: Type of the challenge. + expires_at: + type: integer + example: 1674840917 + description: UNIX seconds of the timestamp past which the challenge should not be verified. 
+ credential_request_options: + $ref: '#/components/schemas/CredentialRequestOptions' + credential_creation_options: + $ref: '#/components/schemas/CredentialCreationOptions' + + CredentialAssertion: + type: object + description: WebAuthn credential assertion options + required: + - challenge + - rpId + - allowCredentials + - timeout + properties: + challenge: + type: string + description: A random challenge generated by the server, base64url encoded + example: "Y2hhbGxlbmdlAyv-5P0kw1SG-OxhLbSHpRLdWaVR1w" + rpId: + type: string + description: The relying party's identifier (usually the domain name) + example: "example.com" + allowCredentials: + type: array + description: List of credentials acceptable for this authentication + items: + type: object + required: + - id + - type + properties: + id: + type: string + description: Credential ID, base64url encoded + example: "AXwyVxYT7BgNKwNq0YqUXaHHIdRK6OdFGCYgZF9K6zNu" + type: + type: string + enum: [public-key] + description: Type of the credential + timeout: + type: integer + description: Time (in milliseconds) that the user has to respond to the authentication prompt + example: 60000 + userVerification: + type: string + enum: [required, preferred, discouraged] + description: The relying party's requirements for user verification + default: preferred + extensions: + type: object + description: Additional parameters requesting additional processing by the client + status: + type: string + enum: [ok, failed] + description: Status of the credential assertion + errorMessage: + type: string + description: Error message if the assertion failed + userHandle: + type: string + description: User handle, base64url encoded + authenticatorAttachment: + type: string + enum: [platform, cross-platform] + description: Type of authenticator to use + + CredentialRequest: + type: object + description: WebAuthn credential request (for the response from the client) + required: + - id + - rawId + - type + - response + properties: + id: + type: string + description: Base64url encoding of the credential ID + example: "AXwyVxYT7BgNKwNq0YqUXaHHIdRK6OdFGCYgZF9K6zNu" + rawId: + type: string + description: Base64url encoding of the credential ID (same as id) + example: "AXwyVxYT7BgNKwNq0YqUXaHHIdRK6OdFGCYgZF9K6zNu" + type: + type: string + enum: [public-key] + description: Type of the credential + response: + type: object + required: + - clientDataJSON + - authenticatorData + - signature + - userHandle + properties: + clientDataJSON: + type: string + description: Base64url encoding of the client data + example: "eyJ0eXBlIjoid2ViYXV0aG4uZ2V0IiwiY2hhbGxlbmdlIjoiY2hhbGxlbmdlIiwib3JpZ2luIjoiaHR0cHM6Ly9leGFtcGxlLmNvbSJ9" + authenticatorData: + type: string + description: Base64url encoding of the authenticator data + example: "SZYN5YgOjGh0NBcPZHZgW4_krrmihjLHmVzzuoMdl2MBAAAAAAAAAAAAAAAAAAAAAAAAAAAAQAXwyVxYT7BgNKwNq0YqUXaHHIdRK6OdFGCYgZF9K6zNu" + signature: + type: string + description: Base64url encoding of the signature + example: "MEUCIQCx5cJVAB3kGP6bqCIoAV6CkBpVAf8rcx0WSZ22fIxXvQIgCKFt9pEu1vK8U4JKYTfn6tGjvGNfx2F4uXrHSXlefvM" + userHandle: + type: string + description: Base64url encoding of the user handle + example: "MQ" + clientExtensionResults: + type: object + description: Client extension results + + CredentialRequestOptions: + type: object + description: Options for requesting an assertion + properties: + challenge: + type: string + format: byte + description: A challenge to be signed by the authenticator + timeout: + type: integer + description: Time (in milliseconds) 
that the caller is willing to wait for the call to complete + rpId: + type: string + description: Relying Party ID + allowCredentials: + type: array + items: + $ref: '#/components/schemas/PublicKeyCredentialDescriptor' + userVerification: + type: string + enum: [required, preferred, discouraged] + description: User verification requirement + + CredentialCreationOptions: + type: object + description: Options for creating a new credential + properties: + rp: + type: object + properties: + id: + type: string + name: + type: string + user: + $ref: '#/components/schemas/UserSchema' + + challenge: + type: string + format: byte + description: A challenge to be signed by the authenticator + pubKeyCredParams: + type: array + items: + type: object + properties: + type: + type: string + enum: [public-key] + alg: + type: integer + timeout: + type: integer + description: Time (in milliseconds) that the caller is willing to wait for the call to complete + excludeCredentials: + type: array + items: + $ref: '#/components/schemas/PublicKeyCredentialDescriptor' + authenticatorSelection: + type: object + properties: + authenticatorAttachment: + type: string + enum: [platform, cross-platform] + requireResidentKey: + type: boolean + userVerification: + type: string + enum: [required, preferred, discouraged] + attestation: + type: string + enum: [none, indirect, direct] + description: Preferred attestation conveyance + + PublicKeyCredentialDescriptor: + type: object + properties: + type: + type: string + enum: [public-key] + id: + type: string + format: byte + description: Credential ID + transports: + type: array + items: + type: string + enum: [usb, nfc, ble, internal] + + responses: + OAuthCallbackRedirectResponse: + description: > + HTTP Redirect to a URL containing the `error` and `error_description` query parameters which should be shown to the user requesting the OAuth sign-in flow. + headers: + Location: + description: > + URL containing the `error` and `error_description` query parameters. + schema: + type: string + format: uri + example: https://example.com/?error=server_error&error_description=User%20does%20not%20exist. + + OAuthAuthorizeRedirectResponse: + description: > + HTTP Redirect to the OAuth identity provider's authorization URL. + headers: + Location: + description: > + URL to which the user agent should redirect (or open in a browser for mobile apps). + schema: + type: string + format: uri + + RateLimitResponse: + description: > + HTTP Too Many Requests response, when a rate limiter has been breached. + content: + application/json: + schema: + type: object + properties: + code: + type: integer + example: 429 + msg: + type: string + description: A basic message describing the rate limit breach. Do not use as an error code identifier. + example: Too many requests. Please try again in a few seconds. + + BadRequestResponse: + description: > + HTTP Bad Request response. Can occur if the passed-in JSON cannot be unmarshalled properly or when CAPTCHA verification was not successful. In certain cases it can also occur when features are disabled on the server (e.g. sign ups). It may also mean that the operation failed because some constraint was not met (for example, a user already exists). + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + UnauthorizedResponse: + description: > + HTTP Unauthorized response. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + ForbiddenResponse: + description: > + HTTP Forbidden response.
+ content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + InternalServerErrorResponse: + description: > + HTTP Internal Server Error. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorSchema" + + AccessRefreshTokenRedirectResponse: + description: > + HTTP See Other redirect response where `Location` is a specially formatted URL that includes an `access_token`, `refresh_token`, `expires_in` as URL query encoded values in the URL fragment (anything after `#`). These values are encoded in the fragment as this value is only visible to the browser handling the redirect and is not sent to the server. + headers: + Location: + schema: + type: string + format: uri + example: https://example.com/#access_token=...&refresh_token=...&expires_in=... diff --git a/chatdesk-ui/.env.local.example b/chatdesk-ui/.env.local.example new file mode 100644 index 0000000..7cff1a7 --- /dev/null +++ b/chatdesk-ui/.env.local.example @@ -0,0 +1,36 @@ +# Supabase Public +NEXT_PUBLIC_SUPABASE_URL= +NEXT_PUBLIC_SUPABASE_ANON_KEY= + +# Supabase Private +SUPABASE_SERVICE_ROLE_KEY= + +# Ollama +NEXT_PUBLIC_OLLAMA_URL=http://localhost:11434 + +# API Keys (Optional: Entering an API key here overrides the API keys globally for all users.) +OPENAI_API_KEY= +ANTHROPIC_API_KEY= +GOOGLE_GEMINI_API_KEY= +MISTRAL_API_KEY= +GROQ_API_KEY= +PERPLEXITY_API_KEY= +OPENROUTER_API_KEY= + +# OpenAI API Information +NEXT_PUBLIC_OPENAI_ORGANIZATION_ID= + +# Azure API Information +AZURE_OPENAI_API_KEY= +AZURE_OPENAI_ENDPOINT= +AZURE_GPT_35_TURBO_NAME= +AZURE_GPT_45_VISION_NAME= +AZURE_GPT_45_TURBO_NAME= +AZURE_EMBEDDINGS_NAME= + +# General Configuration (Optional) +EMAIL_DOMAIN_WHITELIST= +EMAIL_WHITELIST= + +# File size limit for uploads in bytes +NEXT_PUBLIC_USER_FILE_SIZE_LIMIT=10485760 \ No newline at end of file diff --git a/chatdesk-ui/.eslintrc.json b/chatdesk-ui/.eslintrc.json new file mode 100644 index 0000000..6ec5479 --- /dev/null +++ b/chatdesk-ui/.eslintrc.json @@ -0,0 +1,25 @@ +{ + "$schema": "https://json.schemastore.org/eslintrc", + "root": true, + "extends": [ + "next/core-web-vitals", + "prettier", + "plugin:tailwindcss/recommended" + ], + "plugins": ["tailwindcss"], + "rules": { + "tailwindcss/no-custom-classname": "off" + }, + "settings": { + "tailwindcss": { + "callees": ["cn", "cva"], + "config": "tailwind.config.js" + } + }, + "overrides": [ + { + "files": ["*.ts", "*.tsx"], + "parser": "@typescript-eslint/parser" + } + ] +} diff --git a/chatdesk-ui/.gitignore b/chatdesk-ui/.gitignore new file mode 100644 index 0000000..f56e22b --- /dev/null +++ b/chatdesk-ui/.gitignore @@ -0,0 +1,46 @@ +# See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
+ +# dependencies +/node_modules +/.pnp +.pnp.js +.yarn/install-state.gz + +# testing +/coverage + +# next.js +/.next/ +/out/ + +# production +/build + +# misc +.DS_Store +*.pem + +# debug +npm-debug.log* +yarn-debug.log* +yarn-error.log* + +# local env files +.env +.env*.local + +# vercel +.vercel + +# typescript +*.tsbuildinfo +next-env.d.ts + +.VSCodeCounter +tool-schemas +custom-prompts + +sw.js +sw.js.map +workbox-*.js +workbox-*.js.map diff --git a/chatdesk-ui/.nvmrc b/chatdesk-ui/.nvmrc new file mode 100644 index 0000000..7ea6a59 --- /dev/null +++ b/chatdesk-ui/.nvmrc @@ -0,0 +1 @@ +v20.11.0 diff --git a/chatdesk-ui/Dockerfile b/chatdesk-ui/Dockerfile new file mode 100644 index 0000000..7bd9cfb --- /dev/null +++ b/chatdesk-ui/Dockerfile @@ -0,0 +1,45 @@ +# ===== 构建阶段 ===== +FROM node:18.20.6 AS builder + +# 安装指定版本的 npm 和 pm2 +RUN npm install -g npm@10.8.2 \ + && npm install -g pm2@5.4.3 + +WORKDIR /app + +# 拷贝依赖文件并安装生产依赖 +COPY package.json package-lock.json ./ +RUN npm ci + +# 拷贝全部源码 +COPY . . + +# 构建项目 +RUN npm run build + + +# ===== 运行阶段 ===== +FROM node:18.20.6 AS runner + +# 安装指定版本的 npm 和 pm2 +RUN npm install -g npm@10.8.2 \ + && npm install -g pm2@5.4.3 + +WORKDIR /app + +# 拷贝依赖声明并安装仅生产依赖 +COPY package.json package-lock.json ./ +RUN npm ci --omit=dev + +# 拷贝构建产物和依赖 +COPY --from=builder /app/.next ./.next +COPY --from=builder /app/public ./public +COPY --from=builder /app/node_modules ./node_modules +COPY --from=builder /app/package.json ./package.json + +# 环境变量与端口 +ENV NODE_ENV=production +EXPOSE 3000 + +# 正确使用 pm2-runtime 保持容器挂起 +CMD ["pm2-runtime", "start", "npm", "--name", "chatai-ui", "--", "run", "start"] diff --git a/chatdesk-ui/README.md b/chatdesk-ui/README.md new file mode 100644 index 0000000..1d50749 --- /dev/null +++ b/chatdesk-ui/README.md @@ -0,0 +1,292 @@ +# Chatbot UI + +The open-source AI chat app for everyone. + +Chatbot UI + +## Demo + +View the latest demo [here](https://x.com/mckaywrigley/status/1738273242283151777?s=20). + +## Updates + +Hey everyone! I've heard your feedback and am working hard on a big update. + +Things like simpler deployment, better backend compatibility, and improved mobile layouts are on their way. + +Be back soon. + +-- Mckay + +## Official Hosted Version + +Use Chatbot UI without having to host it yourself! + +Find the official hosted version of Chatbot UI [here](https://chatbotui.com). + +## Sponsor + +If you find Chatbot UI useful, please consider [sponsoring](https://github.com/sponsors/mckaywrigley) me to support my open-source work :) + +## Issues + +We restrict "Issues" to actual issues related to the codebase. + +We're getting excessive amounts of issues that amount to things like feature requests, cloud provider issues, etc. + +If you are having issues with things like setup, please refer to the "Help" section in the "Discussions" tab above. + +Issues unrelated to the codebase will likely be closed immediately. + +## Discussions + +We highly encourage you to participate in the "Discussions" tab above! + +Discussions are a great place to ask questions, share ideas, and get help. + +Odds are if you have a question, someone else has the same question. + +## Legacy Code + +Chatbot UI was recently updated to its 2.0 version. + +The code for 1.0 can be found on the `legacy` branch. 
+ +## Updating + +In your terminal at the root of your local Chatbot UI repository, run: + +```bash +npm run update +``` + +If you run a hosted instance you'll also need to run: + +```bash +npm run db-push +``` + +to apply the latest migrations to your live database. + +## Local Quickstart + +Follow these steps to get your own Chatbot UI instance running locally. + +You can watch the full video tutorial [here](https://www.youtube.com/watch?v=9Qq3-7-HNgw). + +### 1. Clone the Repo + +```bash +git clone https://github.com/mckaywrigley/chatbot-ui.git +``` + +### 2. Install Dependencies + +Open a terminal in the root directory of your local Chatbot UI repository and run: + +```bash +npm install +``` + +### 3. Install Supabase & Run Locally + +#### Why Supabase? + +Previously, we used local browser storage to store data. However, this was not a good solution for a few reasons: + +- Security issues +- Limited storage +- Limits multi-modal use cases + +We now use Supabase because it's easy to use, it's open-source, it's Postgres, and it has a free tier for hosted instances. + +We will support other providers in the future to give you more options. + +#### 1. Install Docker + +You will need to install Docker to run Supabase locally. You can download it [here](https://docs.docker.com/get-docker) for free. + +#### 2. Install Supabase CLI + +**MacOS/Linux** + +```bash +brew install supabase/tap/supabase +``` + +**Windows** + +```bash +scoop bucket add supabase https://github.com/supabase/scoop-bucket.git +scoop install supabase +``` + +#### 3. Start Supabase + +In your terminal at the root of your local Chatbot UI repository, run: + +```bash +supabase start +``` + +### 4. Fill in Secrets + +#### 1. Environment Variables + +In your terminal at the root of your local Chatbot UI repository, run: + +```bash +cp .env.local.example .env.local +``` + +Get the required values by running: + +```bash +supabase status +``` + +Note: Use `API URL` from `supabase status` for `NEXT_PUBLIC_SUPABASE_URL` + +Now go to your `.env.local` file and fill in the values. + +If the environment variable is set, it will disable the input in the user settings. + +#### 2. SQL Setup + +In the 1st migration file `supabase/migrations/20240108234540_setup.sql` you will need to replace 2 values with the values you got above: + +- `project_url` (line 53): `http://supabase_kong_chatbotui:8000` (default) can remain unchanged if you don't change your `project_id` in the `config.toml` file +- `service_role_key` (line 54): You got this value from running `supabase status` + +This prevents issues with storage files not being deleted properly. + +### 5. Install Ollama (optional for local models) + +Follow the instructions [here](https://github.com/jmorganca/ollama#macos). + +### 6. Run app locally + +In your terminal at the root of your local Chatbot UI repository, run: + +```bash +npm run chat +``` + +Your local instance of Chatbot UI should now be running at [http://localhost:3000](http://localhost:3000). Be sure to use a compatible node version (i.e. v18). + +You can view your backend GUI at [http://localhost:54323/project/default/editor](http://localhost:54323/project/default/editor). + +## Hosted Quickstart + +Follow these steps to get your own Chatbot UI instance running in the cloud. + +Video tutorial coming soon. + +### 1. Follow Local Quickstart + +Repeat steps 1-4 in "Local Quickstart" above. + +You will want separate repositories for your local and hosted instances. 
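Before moving on: the two `NEXT_PUBLIC_*` values you filled in above are consumed by the browser-side Supabase client. A rough sketch of what that wiring looks like (the exact file in this repo may differ, so treat this as illustrative, not authoritative):

```ts
// Illustrative only: how NEXT_PUBLIC_SUPABASE_URL and NEXT_PUBLIC_SUPABASE_ANON_KEY
// are typically wired into a browser client. Check lib/supabase in this repo for
// the real implementation.
import { createClient } from "@supabase/supabase-js"

export const supabase = createClient(
  process.env.NEXT_PUBLIC_SUPABASE_URL!,
  process.env.NEXT_PUBLIC_SUPABASE_ANON_KEY!
)
```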
+ +Create a new repository for your hosted instance of Chatbot UI on GitHub and push your code to it. + +### 2. Setup Backend with Supabase + +#### 1. Create a new project + +Go to [Supabase](https://supabase.com/) and create a new project. + +#### 2. Get Project Values + +Once you are in the project dashboard, click on the "Project Settings" icon tab on the far bottom left. + +Here you will get the values for the following environment variables: + +- `Project Ref`: Found in "General settings" as "Reference ID" + +- `Project ID`: Found in the URL of your project dashboard (Ex: https://supabase.com/dashboard/project//settings/general) + +While still in "Settings" click on the "API" text tab on the left. + +Here you will get the values for the following environment variables: + +- `Project URL`: Found in "API Settings" as "Project URL" + +- `Anon key`: Found in "Project API keys" as "anon public" + +- `Service role key`: Found in "Project API keys" as "service_role" (Reminder: Treat this like a password!) + +#### 3. Configure Auth + +Next, click on the "Authentication" icon tab on the far left. + +In the text tabs, click on "Providers" and make sure "Email" is enabled. + +We recommend turning off "Confirm email" for your own personal instance. + +#### 4. Connect to Hosted DB + +Open up your repository for your hosted instance of Chatbot UI. + +In the 1st migration file `supabase/migrations/20240108234540_setup.sql` you will need to replace 2 values with the values you got above: + +- `project_url` (line 53): Use the `Project URL` value from above +- `service_role_key` (line 54): Use the `Service role key` value from above + +Now, open a terminal in the root directory of your local Chatbot UI repository. We will execute a few commands here. + +Login to Supabase by running: + +```bash +supabase login +``` + +Next, link your project by running the following command with the "Project ID" you got above: + +```bash +supabase link --project-ref +``` + +Your project should now be linked. + +Finally, push your database to Supabase by running: + +```bash +supabase db push +``` + +Your hosted database should now be set up! + +### 3. Setup Frontend with Vercel + +Go to [Vercel](https://vercel.com/) and create a new project. + +In the setup page, import your GitHub repository for your hosted instance of Chatbot UI. Within the project Settings, in the "Build & Development Settings" section, switch Framework Preset to "Next.js". + +In environment variables, add the following from the values you got above: + +- `NEXT_PUBLIC_SUPABASE_URL` +- `NEXT_PUBLIC_SUPABASE_ANON_KEY` +- `SUPABASE_SERVICE_ROLE_KEY` +- `NEXT_PUBLIC_OLLAMA_URL` (only needed when using local Ollama models; default: `http://localhost:11434`) + +You can also add API keys as environment variables. + +- `OPENAI_API_KEY` +- `AZURE_OPENAI_API_KEY` +- `AZURE_OPENAI_ENDPOINT` +- `AZURE_GPT_45_VISION_NAME` + +For the full list of environment variables, refer to the '.env.local.example' file. If the environment variables are set for API keys, it will disable the input in the user settings. + +Click "Deploy" and wait for your frontend to deploy. + +Once deployed, you should be able to use your hosted instance of Chatbot UI via the URL Vercel gives you. + +## Contributing + +We are working on a guide for contributing. 
+ +## Contact + +Message Mckay on [Twitter/X](https://twitter.com/mckaywrigley) diff --git a/chatdesk-ui/__tests__/lib/openapi-conversion.test.ts b/chatdesk-ui/__tests__/lib/openapi-conversion.test.ts new file mode 100644 index 0000000..7505590 --- /dev/null +++ b/chatdesk-ui/__tests__/lib/openapi-conversion.test.ts @@ -0,0 +1,369 @@ +import { openapiToFunctions } from "@/lib/openapi-conversion" + +const validSchemaURL = JSON.stringify({ + openapi: "3.1.0", + info: { + title: "Get weather data", + description: "Retrieves current weather data for a location.", + version: "v1.0.0" + }, + servers: [ + { + url: "https://weather.example.com" + } + ], + paths: { + "/location": { + get: { + description: "Get temperature for a specific location", + operationId: "GetCurrentWeather", + parameters: [ + { + name: "location", + in: "query", + description: "The city and state to retrieve the weather for", + required: true, + schema: { + type: "string" + } + } + ] + } + }, + "/summary": { + get: { + description: "Get description of weather for a specific location", + operationId: "GetWeatherSummary", + parameters: [ + { + name: "location", + in: "query", + description: "The city and state to retrieve the summary for", + required: true, + schema: { + type: "string" + } + } + ] + } + } + } +}) + +describe("extractOpenapiData for url", () => { + it("should parse a valid OpenAPI url schema", async () => { + const { info, routes, functions } = await openapiToFunctions( + JSON.parse(validSchemaURL) + ) + + expect(info.title).toBe("Get weather data") + expect(info.description).toBe( + "Retrieves current weather data for a location." + ) + expect(info.server).toBe("https://weather.example.com") + + expect(routes).toHaveLength(2) + + expect(functions).toHaveLength(2) + expect(functions[0].function.name).toBe("GetCurrentWeather") + expect(functions[1].function.name).toBe("GetWeatherSummary") + }) +}) + +const validSchemaBody = JSON.stringify({ + openapi: "3.1.0", + info: { + title: "Get weather data", + description: "Retrieves current weather data for a location.", + version: "v1.0.0" + }, + servers: [ + { + url: "https://weather.example.com" + } + ], + paths: { + "/location": { + post: { + description: "Get temperature for a specific location", + operationId: "GetCurrentWeather", + requestBody: { + required: true, + content: { + "application/json": { + schema: { + type: "object", + properties: { + location: { + type: "string", + description: + "The city and state to retrieve the weather for", + example: "New York, NY" + } + } + } + } + } + } + } + } + } +}) + +describe("extractOpenapiData for body", () => { + it("should parse a valid OpenAPI body schema", async () => { + const { info, routes, functions } = await openapiToFunctions( + JSON.parse(validSchemaBody) + ) + + expect(info.title).toBe("Get weather data") + expect(info.description).toBe( + "Retrieves current weather data for a location." 
+ ) + expect(info.server).toBe("https://weather.example.com") + + expect(routes).toHaveLength(1) + expect(routes[0].path).toBe("/location") + expect(routes[0].method).toBe("post") + expect(routes[0].operationId).toBe("GetCurrentWeather") + + expect(functions).toHaveLength(1) + expect( + functions[0].function.parameters.properties.requestBody.properties + .location.type + ).toBe("string") + expect( + functions[0].function.parameters.properties.requestBody.properties + .location.description + ).toBe("The city and state to retrieve the weather for") + }) +}) + +const validSchemaBody2 = JSON.stringify({ + openapi: "3.1.0", + info: { + title: "Polygon.io Stock and Crypto Data API", + description: + "API schema for accessing stock and crypto data from Polygon.io.", + version: "1.0.0" + }, + servers: [ + { + url: "https://api.polygon.io" + } + ], + paths: { + "/v1/open-close/{stocksTicker}/{date}": { + get: { + summary: "Get Stock Daily Open and Close", + description: "Get the daily open and close for a specific stock.", + operationId: "getStockDailyOpenClose", + parameters: [ + { + name: "stocksTicker", + in: "path", + required: true, + schema: { + type: "string" + } + }, + { + name: "date", + in: "path", + required: true, + schema: { + type: "string", + format: "date" + } + } + ] + } + }, + "/v2/aggs/ticker/{stocksTicker}/prev": { + get: { + summary: "Get Stock Previous Close", + description: "Get the previous closing data for a specific stock.", + operationId: "getStockPreviousClose", + parameters: [ + { + name: "stocksTicker", + in: "path", + required: true, + schema: { + type: "string" + } + } + ] + } + }, + "/v3/trades/{stockTicker}": { + get: { + summary: "Get Stock Trades", + description: "Retrieve trades for a specific stock.", + operationId: "getStockTrades", + parameters: [ + { + name: "stockTicker", + in: "path", + required: true, + schema: { + type: "string" + } + } + ] + } + }, + "/v3/trades/{optionsTicker}": { + get: { + summary: "Get Options Trades", + description: "Retrieve trades for a specific options ticker.", + operationId: "getOptionsTrades", + parameters: [ + { + name: "optionsTicker", + in: "path", + required: true, + schema: { + type: "string" + } + } + ] + } + }, + "/v2/last/trade/{optionsTicker}": { + get: { + summary: "Get Last Options Trade", + description: "Get the last trade for a specific options ticker.", + operationId: "getLastOptionsTrade", + parameters: [ + { + name: "optionsTicker", + in: "path", + required: true, + schema: { + type: "string" + } + } + ] + } + }, + "/v1/open-close/crypto/{from}/{to}/{date}": { + get: { + summary: "Get Crypto Daily Open and Close", + description: + "Get daily open and close data for a specific cryptocurrency.", + operationId: "getCryptoDailyOpenClose", + parameters: [ + { + name: "from", + in: "path", + required: true, + schema: { + type: "string" + } + }, + { + name: "to", + in: "path", + required: true, + schema: { + type: "string" + } + }, + { + name: "date", + in: "path", + required: true, + schema: { + type: "string", + format: "date" + } + } + ] + } + }, + "/v2/aggs/ticker/{cryptoTicker}/prev": { + get: { + summary: "Get Crypto Previous Close", + description: + "Get the previous closing data for a specific cryptocurrency.", + operationId: "getCryptoPreviousClose", + parameters: [ + { + name: "cryptoTicker", + in: "path", + required: true, + schema: { + type: "string" + } + } + ] + } + } + }, + components: { + securitySchemes: { + BearerAuth: { + type: "http", + scheme: "bearer", + bearerFormat: "API Key" + } + } + }, + 
security: [ + { + BearerAuth: [] + } + ] +}) + +describe("extractOpenapiData for body 2", () => { + it("should parse a valid OpenAPI body schema for body 2", async () => { + const { info, routes, functions } = await openapiToFunctions( + JSON.parse(validSchemaBody2) + ) + + expect(info.title).toBe("Polygon.io Stock and Crypto Data API") + expect(info.description).toBe( + "API schema for accessing stock and crypto data from Polygon.io." + ) + expect(info.server).toBe("https://api.polygon.io") + + expect(routes).toHaveLength(7) + expect(routes[0].path).toBe("/v1/open-close/{stocksTicker}/{date}") + expect(routes[0].method).toBe("get") + expect(routes[0].operationId).toBe("getStockDailyOpenClose") + + expect(functions[0].function.parameters.properties).toHaveProperty( + "stocksTicker" + ) + expect(functions[0].function.parameters.properties.stocksTicker.type).toBe( + "string" + ) + expect( + functions[0].function.parameters.properties.stocksTicker + ).toHaveProperty("required", true) + expect(functions[0].function.parameters.properties).toHaveProperty("date") + expect(functions[0].function.parameters.properties.date.type).toBe("string") + expect(functions[0].function.parameters.properties.date).toHaveProperty( + "format", + "date" + ) + expect(functions[0].function.parameters.properties.date).toHaveProperty( + "required", + true + ) + expect(routes[1].path).toBe("/v2/aggs/ticker/{stocksTicker}/prev") + expect(routes[1].method).toBe("get") + expect(routes[1].operationId).toBe("getStockPreviousClose") + expect(functions[1].function.parameters.properties).toHaveProperty( + "stocksTicker" + ) + expect(functions[1].function.parameters.properties.stocksTicker.type).toBe( + "string" + ) + expect( + functions[1].function.parameters.properties.stocksTicker + ).toHaveProperty("required", true) + }) +}) diff --git a/chatdesk-ui/__tests__/playwright-test/.gitignore b/chatdesk-ui/__tests__/playwright-test/.gitignore new file mode 100644 index 0000000..68c5d18 --- /dev/null +++ b/chatdesk-ui/__tests__/playwright-test/.gitignore @@ -0,0 +1,5 @@ +node_modules/ +/test-results/ +/playwright-report/ +/blob-report/ +/playwright/.cache/ diff --git a/chatdesk-ui/__tests__/playwright-test/package-lock.json b/chatdesk-ui/__tests__/playwright-test/package-lock.json new file mode 100644 index 0000000..6b2675d --- /dev/null +++ b/chatdesk-ui/__tests__/playwright-test/package-lock.json @@ -0,0 +1,91 @@ +{ + "name": "playwright-test", + "version": "1.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "playwright-test", + "version": "1.0.0", + "license": "ISC", + "devDependencies": { + "@playwright/test": "^1.41.2", + "@types/node": "^20.11.20" + } + }, + "node_modules/@playwright/test": { + "version": "1.41.2", + "resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.41.2.tgz", + "integrity": "sha512-qQB9h7KbibJzrDpkXkYvsmiDJK14FULCCZgEcoe2AvFAS64oCirWTwzTlAYEbKaRxWs5TFesE1Na6izMv3HfGg==", + "dev": true, + "dependencies": { + "playwright": "1.41.2" + }, + "bin": { + "playwright": "cli.js" + }, + "engines": { + "node": ">=16" + } + }, + "node_modules/@types/node": { + "version": "20.11.20", + "resolved": "https://registry.npmjs.org/@types/node/-/node-20.11.20.tgz", + "integrity": "sha512-7/rR21OS+fq8IyHTgtLkDK949uzsa6n8BkziAKtPVpugIkO6D+/ooXMvzXxDnZrmtXVfjb1bKQafYpb8s89LOg==", + "dev": true, + "dependencies": { + "undici-types": "~5.26.4" + } + }, + "node_modules/fsevents": { + "version": "2.3.2", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", + 
"integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==", + "dev": true, + "hasInstallScript": true, + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/playwright": { + "version": "1.41.2", + "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.41.2.tgz", + "integrity": "sha512-v0bOa6H2GJChDL8pAeLa/LZC4feoAMbSQm1/jF/ySsWWoaNItvrMP7GEkvEEFyCTUYKMxjQKaTSg5up7nR6/8A==", + "dev": true, + "dependencies": { + "playwright-core": "1.41.2" + }, + "bin": { + "playwright": "cli.js" + }, + "engines": { + "node": ">=16" + }, + "optionalDependencies": { + "fsevents": "2.3.2" + } + }, + "node_modules/playwright-core": { + "version": "1.41.2", + "resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.41.2.tgz", + "integrity": "sha512-VaTvwCA4Y8kxEe+kfm2+uUUw5Lubf38RxF7FpBxLPmGe5sdNkSg5e3ChEigaGrX7qdqT3pt2m/98LiyvU2x6CA==", + "dev": true, + "bin": { + "playwright-core": "cli.js" + }, + "engines": { + "node": ">=16" + } + }, + "node_modules/undici-types": { + "version": "5.26.5", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", + "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==", + "dev": true + } + } +} diff --git a/chatdesk-ui/__tests__/playwright-test/package.json b/chatdesk-ui/__tests__/playwright-test/package.json new file mode 100644 index 0000000..286e30d --- /dev/null +++ b/chatdesk-ui/__tests__/playwright-test/package.json @@ -0,0 +1,18 @@ +{ + "name": "playwright-test", + "version": "1.0.0", + "description": "", + "main": "index.js", + "scripts": { + "integration": "playwright test", + "integration:open": "playwright test --ui", + "integration:codegen": "playwright codegen" + }, + "keywords": [], + "author": "", + "license": "ISC", + "devDependencies": { + "@playwright/test": "^1.41.2", + "@types/node": "^20.11.20" + } +} diff --git a/chatdesk-ui/__tests__/playwright-test/playwright.config.ts b/chatdesk-ui/__tests__/playwright-test/playwright.config.ts new file mode 100644 index 0000000..301801e --- /dev/null +++ b/chatdesk-ui/__tests__/playwright-test/playwright.config.ts @@ -0,0 +1,77 @@ +import { defineConfig, devices } from '@playwright/test'; + +/** + * Read environment variables from file. + * https://github.com/motdotla/dotenv + */ +// require('dotenv').config(); + +/** + * See https://playwright.dev/docs/test-configuration. + */ +export default defineConfig({ + testDir: './tests', + /* Run tests in files in parallel */ + fullyParallel: true, + /* Fail the build on CI if you accidentally left test.only in the source code. */ + forbidOnly: !!process.env.CI, + /* Retry on CI only */ + retries: process.env.CI ? 2 : 0, + /* Opt out of parallel tests on CI. */ + workers: process.env.CI ? 1 : undefined, + /* Reporter to use. See https://playwright.dev/docs/test-reporters */ + reporter: 'html', + /* Shared settings for all the projects below. See https://playwright.dev/docs/api/class-testoptions. */ + use: { + /* Base URL to use in actions like `await page.goto('/')`. */ + // baseURL: 'http://127.0.0.1:3000', + + /* Collect trace when retrying the failed test. 
See https://playwright.dev/docs/trace-viewer */ + trace: 'on-first-retry', + }, + + /* Configure projects for major browsers */ + projects: [ + { + name: 'chromium', + use: { ...devices['Desktop Chrome'] }, + }, + + { + name: 'firefox', + use: { ...devices['Desktop Firefox'] }, + }, + + { + name: 'webkit', + use: { ...devices['Desktop Safari'] }, + }, + + /* Test against mobile viewports. */ + // { + // name: 'Mobile Chrome', + // use: { ...devices['Pixel 5'] }, + // }, + // { + // name: 'Mobile Safari', + // use: { ...devices['iPhone 12'] }, + // }, + + /* Test against branded browsers. */ + // { + // name: 'Microsoft Edge', + // use: { ...devices['Desktop Edge'], channel: 'msedge' }, + // }, + // { + // name: 'Google Chrome', + // use: { ...devices['Desktop Chrome'], channel: 'chrome' }, + // }, + ], + + /* Run your local dev server before starting the tests */ + // webServer: { + // command: 'npm run start', + // url: 'http://127.0.0.1:3000', + // reuseExistingServer: !process.env.CI, + // }, +}); diff --git a/chatdesk-ui/__tests__/playwright-test/tests/login.spec.ts b/chatdesk-ui/__tests__/playwright-test/tests/login.spec.ts new file mode 100644 index 0000000..0b69a70 --- /dev/null +++ b/chatdesk-ui/__tests__/playwright-test/tests/login.spec.ts @@ -0,0 +1,46 @@ +import { test, expect } from '@playwright/test'; + +test('start chatting is displayed', async ({ page }) => { + await page.goto('http://localhost:3000/'); + + //expect the start chatting link to be visible + await expect(page.getByRole('link', { name: 'Start Chatting' })).toBeVisible(); +}); + +test('No password error message', async ({ page }) => { + await page.goto('http://localhost:3000/login'); + //fill in dummy email + await page.getByPlaceholder('you@example.com').fill('dummyemail@gmail.com'); + await page.getByRole('button', { name: 'Login' }).click(); + //wait for network to be idle + await page.waitForLoadState('networkidle'); + //validate that the correct message is shown to the user + await expect(page.getByText('Invalid login credentials')).toBeVisible(); + +}); +test('No password for signup', async ({ page }) => { + await page.goto('http://localhost:3000/login'); + + await page.getByPlaceholder('you@example.com').fill('dummyEmail@Gmail.com'); + await page.getByRole('button', { name: 'Sign Up' }).click(); + //validate appropriate error is thrown for missing password when signing up + await expect(page.getByText('Signup requires a valid')).toBeVisible(); +}); +test('invalid username for signup', async ({ page }) => { + await page.goto('http://localhost:3000/login'); + + await page.getByPlaceholder('you@example.com').fill('dummyEmail'); + await page.getByPlaceholder('••••••••').fill('dummypassword'); + await page.getByRole('button', { name: 'Sign Up' }).click(); + //validate appropriate error is thrown for invalid username when signing up + await expect(page.getByText('Unable to validate email')).toBeVisible(); +}); +test('password reset message', async ({ page }) => { + await page.goto('http://localhost:3000/login'); + await page.getByPlaceholder('you@example.com').fill('demo@gmail.com'); + await page.getByRole('button', { name: 'Reset' }).click(); + //validate appropriate message is shown + await expect(page.getByText('Check email to reset password')).toBeVisible(); +}); + +//more tests can be added here \ No newline at end of file diff --git a/chatdesk-ui/app/[locale]/[workspaceid]/chat/[chatid]/page.tsx b/chatdesk-ui/app/[locale]/[workspaceid]/chat/[chatid]/page.tsx new file mode 100644 index 0000000..30d082e ---
/dev/null +++ b/chatdesk-ui/app/[locale]/[workspaceid]/chat/[chatid]/page.tsx @@ -0,0 +1,7 @@ +"use client" + +import { ChatUI } from "@/components/chat/chat-ui" + +export default function ChatIDPage() { + return +} diff --git a/chatdesk-ui/app/[locale]/[workspaceid]/chat/page.tsx b/chatdesk-ui/app/[locale]/[workspaceid]/chat/page.tsx new file mode 100644 index 0000000..8de866e --- /dev/null +++ b/chatdesk-ui/app/[locale]/[workspaceid]/chat/page.tsx @@ -0,0 +1,62 @@ +"use client" + +import { ChatHelp } from "@/components/chat/chat-help" +import { useChatHandler } from "@/components/chat/chat-hooks/use-chat-handler" +import { ChatInput } from "@/components/chat/chat-input" +import { ChatSettings } from "@/components/chat/chat-settings" +import { ChatUI } from "@/components/chat/chat-ui" +import { QuickSettings } from "@/components/chat/quick-settings" +import { Brand } from "@/components/ui/brand" +import { ChatbotUIContext } from "@/context/context" +import useHotkey from "@/lib/hooks/use-hotkey" +import { useTheme } from "next-themes" +import { useContext } from "react" + +import { useTranslation } from 'react-i18next' + +export default function ChatPage() { + useHotkey("o", () => handleNewChat()) + useHotkey("l", () => { + handleFocusChatInput() + }) + + const { chatMessages } = useContext(ChatbotUIContext) + + const { handleNewChat, handleFocusChatInput } = useChatHandler() + + const { theme } = useTheme() + + const { t } = useTranslation() + + return ( + <> + {chatMessages.length === 0 ? ( +
+
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+ +
+
+ ) : ( + + )} + + ) +} diff --git a/chatdesk-ui/app/[locale]/[workspaceid]/layout.tsx b/chatdesk-ui/app/[locale]/[workspaceid]/layout.tsx new file mode 100644 index 0000000..218cba6 --- /dev/null +++ b/chatdesk-ui/app/[locale]/[workspaceid]/layout.tsx @@ -0,0 +1,193 @@ +"use client" + +import { usePathname } from "next/navigation" + +import { Dashboard } from "@/components/ui/dashboard" +import { ChatbotUIContext } from "@/context/context" +import { getAssistantWorkspacesByWorkspaceId } from "@/db/assistants" +import { getChatsByWorkspaceId } from "@/db/chats" +import { getCollectionWorkspacesByWorkspaceId } from "@/db/collections" +import { getFileWorkspacesByWorkspaceId } from "@/db/files" +import { getFoldersByWorkspaceId } from "@/db/folders" +import { getModelWorkspacesByWorkspaceId } from "@/db/models" +import { getPresetWorkspacesByWorkspaceId } from "@/db/presets" +import { getPromptWorkspacesByWorkspaceId } from "@/db/prompts" +import { getAssistantImageFromStorage } from "@/db/storage/assistant-images" +import { getToolWorkspacesByWorkspaceId } from "@/db/tools" +import { getWorkspaceById } from "@/db/workspaces" +import { convertBlobToBase64 } from "@/lib/blob-to-b64" +import { supabase } from "@/lib/supabase/browser-client" +import { LLMID } from "@/types" +import { useParams, useRouter, useSearchParams } from "next/navigation" +import { ReactNode, useContext, useEffect, useState } from "react" +import Loading from "../loading" +import { useTranslation } from 'react-i18next' + +interface WorkspaceLayoutProps { + children: ReactNode +} + +export default function WorkspaceLayout({ children }: WorkspaceLayoutProps) { + const { t } = useTranslation() + + const router = useRouter() + const pathname = usePathname() // 获取当前路径 + + // 提取当前路径中的 locale 部分 + const locale = pathname.split("/")[1] || "en" + + const params = useParams() + const searchParams = useSearchParams() + const workspaceId = params.workspaceid as string + + const { + setChatSettings, + setAssistants, + setAssistantImages, + setChats, + setCollections, + setFolders, + setFiles, + setPresets, + setPrompts, + setTools, + setModels, + selectedWorkspace, + setSelectedWorkspace, + setSelectedChat, + setChatMessages, + setUserInput, + setIsGenerating, + setFirstTokenReceived, + setChatFiles, + setChatImages, + setNewMessageFiles, + setNewMessageImages, + setShowFilesDisplay + } = useContext(ChatbotUIContext) + + const [loading, setLoading] = useState(true) + + useEffect(() => { + ;(async () => { + const session = (await supabase.auth.getSession()).data.session + + if (!session) { + // 跳转到带有 locale 的登录页面 + return router.push(`/${locale}/login`) + } else { + await fetchWorkspaceData(workspaceId) + } + })() + }, []) + + useEffect(() => { + ;(async () => await fetchWorkspaceData(workspaceId))() + + setUserInput("") + setChatMessages([]) + setSelectedChat(null) + + setIsGenerating(false) + setFirstTokenReceived(false) + + setChatFiles([]) + setChatImages([]) + setNewMessageFiles([]) + setNewMessageImages([]) + setShowFilesDisplay(false) + }, [workspaceId]) + + const fetchWorkspaceData = async (workspaceId: string) => { + setLoading(true) + + const workspace = await getWorkspaceById(workspaceId) + setSelectedWorkspace(workspace) + + const assistantData = await getAssistantWorkspacesByWorkspaceId(workspaceId) + setAssistants(assistantData.assistants) + + for (const assistant of assistantData.assistants) { + let url = "" + + if (assistant.image_path) { + url = (await getAssistantImageFromStorage(assistant.image_path)) || "" + } + + 
if (url) { + const response = await fetch(url) + const blob = await response.blob() + const base64 = await convertBlobToBase64(blob) + + setAssistantImages(prev => [ + ...prev, + { + assistantId: assistant.id, + path: assistant.image_path, + base64, + url + } + ]) + } else { + setAssistantImages(prev => [ + ...prev, + { + assistantId: assistant.id, + path: assistant.image_path, + base64: "", + url + } + ]) + } + } + + const chats = await getChatsByWorkspaceId(workspaceId) + setChats(chats) + + const collectionData = + await getCollectionWorkspacesByWorkspaceId(workspaceId) + setCollections(collectionData.collections) + + const folders = await getFoldersByWorkspaceId(workspaceId) + setFolders(folders) + + const fileData = await getFileWorkspacesByWorkspaceId(workspaceId) + setFiles(fileData.files) + + const presetData = await getPresetWorkspacesByWorkspaceId(workspaceId) + setPresets(presetData.presets) + + const promptData = await getPromptWorkspacesByWorkspaceId(workspaceId) + setPrompts(promptData.prompts) + + const toolData = await getToolWorkspacesByWorkspaceId(workspaceId) + setTools(toolData.tools) + + const modelData = await getModelWorkspacesByWorkspaceId(workspaceId) + setModels(modelData.models) + + setChatSettings({ + model: (searchParams.get("model") || + workspace?.default_model || + "gpt-4-1106-preview") as LLMID, + prompt: + workspace?.default_prompt || + t("chat.promptPlaceholder"), + temperature: workspace?.default_temperature || 0.5, + contextLength: workspace?.default_context_length || 4096, + includeProfileContext: workspace?.include_profile_context || true, + includeWorkspaceInstructions: + workspace?.include_workspace_instructions || true, + embeddingsProvider: + (workspace?.embeddings_provider as "openai" | "local") || "openai" + }) + + setLoading(false) + } + + if (loading) { + return + } + + return {children} +} diff --git a/chatdesk-ui/app/[locale]/[workspaceid]/page.tsx b/chatdesk-ui/app/[locale]/[workspaceid]/page.tsx new file mode 100644 index 0000000..b43e8d5 --- /dev/null +++ b/chatdesk-ui/app/[locale]/[workspaceid]/page.tsx @@ -0,0 +1,14 @@ +"use client" + +import { ChatbotUIContext } from "@/context/context" +import { useContext } from "react" + +export default function WorkspacePage() { + const { selectedWorkspace } = useContext(ChatbotUIContext) + + return ( +
+
{selectedWorkspace?.name}
+
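One detail in the `setChatSettings` defaults above: `workspace?.include_profile_context || true` (and the `include_workspace_instructions` line) can never produce `false`, because `||` falls back on any falsy value. If the intent is to default only when the column is null/undefined, nullish coalescing behaves differently — a minimal illustration:

```ts
const storedFalse: boolean | null = false
const storedUnset: boolean | null = null

console.log(storedFalse || true) // true  — an explicit false is overridden
console.log(storedFalse ?? true) // false — ?? preserves the stored value
console.log(storedUnset ?? true) // true  — falls back only on null/undefined
```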
+ ) +} diff --git a/chatdesk-ui/app/[locale]/globals.css b/chatdesk-ui/app/[locale]/globals.css new file mode 100644 index 0000000..c0d1efc --- /dev/null +++ b/chatdesk-ui/app/[locale]/globals.css @@ -0,0 +1,104 @@ +@tailwind base; +@tailwind components; +@tailwind utilities; + +::-webkit-scrollbar-track { + background-color: transparent; +} + +::-webkit-scrollbar-thumb { + background-color: #ccc; + border-radius: 10px; +} + +::-webkit-scrollbar-thumb:hover { + background-color: #aaa; +} + +::-webkit-scrollbar-track:hover { + background-color: #f2f2f2; +} + +::-webkit-scrollbar-corner { + background-color: transparent; +} + +::-webkit-scrollbar { + width: 6px; + height: 6px; +} + +@layer base { + :root { + --background: 0 0% 100%; + --foreground: 0 0% 3.9%; + + --muted: 0 0% 96.1%; + --muted-foreground: 0 0% 45.1%; + + --popover: 0 0% 100%; + --popover-foreground: 0 0% 3.9%; + + --card: 0 0% 100%; + --card-foreground: 0 0% 3.9%; + + --border: 0 0% 89.8%; + --input: 0 0% 89.8%; + + --primary: 0 0% 9%; + --primary-foreground: 0 0% 98%; + + --secondary: 0 0% 96.1%; + --secondary-foreground: 0 0% 9%; + + --accent: 0 0% 96.1%; + --accent-foreground: 0 0% 9%; + + --destructive: 0 84.2% 60.2%; + --destructive-foreground: 0 0% 98%; + + --ring: 0 0% 63.9%; + + --radius: 0.5rem; + } + + .dark { + --background: 0 0% 3.9%; + --foreground: 0 0% 98%; + + --muted: 0 0% 14.9%; + --muted-foreground: 0 0% 63.9%; + + --popover: 0 0% 3.9%; + --popover-foreground: 0 0% 98%; + + --card: 0 0% 3.9%; + --card-foreground: 0 0% 98%; + + --border: 0 0% 14.9%; + --input: 0 0% 14.9%; + + --primary: 0 0% 98%; + --primary-foreground: 0 0% 9%; + + --secondary: 0 0% 14.9%; + --secondary-foreground: 0 0% 98%; + + --accent: 0 0% 14.9%; + --accent-foreground: 0 0% 98%; + + --destructive: 0 62.8% 30.6%; + --destructive-foreground: 0 85.7% 97.3%; + + --ring: 0 0% 14.9%; + } +} + +@layer base { + * { + @apply border-border; + } + body { + @apply bg-background text-foreground; + } +} diff --git a/chatdesk-ui/app/[locale]/help/page.tsx b/chatdesk-ui/app/[locale]/help/page.tsx new file mode 100644 index 0000000..c1753f4 --- /dev/null +++ b/chatdesk-ui/app/[locale]/help/page.tsx @@ -0,0 +1,7 @@ +export default function HelpPage() { + return ( +
+
Help under construction.
+
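The CSS variables above hold bare HSL channel values (e.g. `--background: 0 0% 100%`), so utilities such as `bg-background` and the `@apply border-border` rule only resolve if the Tailwind config maps those names through `hsl(var(--…))`. The project's `tailwind.config` is not included in this hunk; the assumed mapping is roughly:

```ts
// Assumed excerpt of tailwind.config.ts — not taken from this commit.
import type { Config } from "tailwindcss"

const config: Config = {
  content: [], // the real config lists the app/ and components/ globs
  theme: {
    extend: {
      colors: {
        border: "hsl(var(--border))",
        input: "hsl(var(--input))",
        ring: "hsl(var(--ring))",
        background: "hsl(var(--background))",
        foreground: "hsl(var(--foreground))",
        primary: {
          DEFAULT: "hsl(var(--primary))",
          foreground: "hsl(var(--primary-foreground))"
        },
        destructive: {
          DEFAULT: "hsl(var(--destructive))",
          foreground: "hsl(var(--destructive-foreground))"
        }
      },
      borderRadius: {
        lg: "var(--radius)"
      }
    }
  }
}

export default config
```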
+ ) +} diff --git a/chatdesk-ui/app/[locale]/layout.tsx b/chatdesk-ui/app/[locale]/layout.tsx new file mode 100644 index 0000000..affc886 --- /dev/null +++ b/chatdesk-ui/app/[locale]/layout.tsx @@ -0,0 +1,170 @@ +import { Toaster } from "@/components/ui/sonner" +import { GlobalState } from "@/components/utility/global-state" +import { Providers } from "@/components/utility/providers" +import TranslationsProvider from "@/components/utility/translations-provider" +import initTranslations from "@/lib/i18n" +import { Database } from "@/supabase/types" +import { createServerClient } from "@supabase/ssr" +import { Metadata, Viewport } from "next" +import { Inter } from "next/font/google" +import { cookies } from "next/headers" +import { ReactNode } from "react" +import "./globals.css" + +const inter = Inter({ subsets: ["latin"] }) +const APP_NAME = "ChatAI UI" +const APP_DEFAULT_TITLE = "ChatAI UI" +const APP_TITLE_TEMPLATE = "%s - ChatAI UI" +const APP_DESCRIPTION = "ChaAI UI PWA!" + +interface RootLayoutProps { + children: ReactNode + params: { + locale: string + } +} + +// export const metadata: Metadata = { +// applicationName: APP_NAME, +// title: { +// default: APP_DEFAULT_TITLE, +// template: APP_TITLE_TEMPLATE +// }, +// description: APP_DESCRIPTION, +// manifest: "/manifest.json", +// appleWebApp: { +// capable: true, +// statusBarStyle: "black", +// title: APP_DEFAULT_TITLE +// // startUpImage: [], +// }, +// formatDetection: { +// telephone: false +// }, +// openGraph: { +// type: "website", +// siteName: APP_NAME, +// title: { +// default: APP_DEFAULT_TITLE, +// template: APP_TITLE_TEMPLATE +// }, +// description: APP_DESCRIPTION +// }, +// twitter: { +// card: "summary", +// title: { +// default: APP_DEFAULT_TITLE, +// template: APP_TITLE_TEMPLATE +// }, +// description: APP_DESCRIPTION +// } +// } + +export async function generateMetadata({ + params: { locale } +}: { + params: { locale: string } +}): Promise { + const { t } = await initTranslations(locale, ["translation"]) + + const appName = t("meta.appName") + const defaultTitle = t("meta.defaultTitle") + const description = t("meta.description") + const titleTemplate = `%s - ${defaultTitle}` + + return { + applicationName: appName, + title: { + default: defaultTitle, + template: titleTemplate + }, + description, + manifest: "/manifest.json", + appleWebApp: { + capable: true, + statusBarStyle: "black", + title: defaultTitle + }, + formatDetection: { + telephone: false + }, + openGraph: { + type: "website", + siteName: appName, + title: { + default: defaultTitle, + template: titleTemplate + }, + description + }, + twitter: { + card: "summary", + title: { + default: defaultTitle, + template: titleTemplate + }, + description + } + } +} + +export const viewport: Viewport = { + themeColor: "#000000" +} + +const i18nNamespaces = ["translation"] + +export default async function RootLayout({ + children, + params: { locale } +}: RootLayoutProps) { + const cookieStore = cookies() + + + // 遍历所有 cookies + for (const cookie of cookieStore.getAll()) { + console.log(`🍪 Cookie: ${cookie.name} = ${cookie.value}`); + } + + const supabase = createServerClient( + process.env.NEXT_PUBLIC_SUPABASE_URL!, + process.env.NEXT_PUBLIC_SUPABASE_ANON_KEY!, + { + cookies: { + get(name: string) { + return cookieStore.get(name)?.value + } + } + } + ) + // const session = (await supabase.auth.getSession()).data.session + const { data, error } = await supabase.auth.getSession(); + if (error) { + console.log("[layout.tsx]............Session Error: ", error); + } 
else { + console.log("[layout.tsx]............Session Data: ", data.session); + } + + const { t, resources } = await initTranslations(locale, i18nNamespaces) + + console.log("[layout.tsx]..............current locale: ", {locale}); + + return ( + + + + + +
+ {data.session ? <GlobalState>{children}</GlobalState> : children}
+
+
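The cookie adapter above is the same pattern the login page reaches through `createClient` from `@/lib/supabase/server` (that helper is not included in this hunk). Its assumed shape, mirroring the inline client in this layout:

```ts
// Assumed shape of @/lib/supabase/server — mirrors the inline createServerClient call above.
import { Database } from "@/supabase/types"
import { createServerClient } from "@supabase/ssr"
import { cookies } from "next/headers"

export const createClient = (cookieStore: ReturnType<typeof cookies>) =>
  createServerClient<Database>(
    process.env.NEXT_PUBLIC_SUPABASE_URL!,
    process.env.NEXT_PUBLIC_SUPABASE_ANON_KEY!,
    {
      cookies: {
        get(name: string) {
          return cookieStore.get(name)?.value
        }
        // A full helper would typically also implement set/remove so that
        // server actions (signIn/signUp below) can persist refreshed auth cookies.
      }
    }
  )
```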
+ + + ) +} diff --git a/chatdesk-ui/app/[locale]/loading.tsx b/chatdesk-ui/app/[locale]/loading.tsx new file mode 100644 index 0000000..4cfc63f --- /dev/null +++ b/chatdesk-ui/app/[locale]/loading.tsx @@ -0,0 +1,9 @@ +import { IconLoader2 } from "@tabler/icons-react" + +export default function Loading() { + return ( +
+ +
+ ) +} diff --git a/chatdesk-ui/app/[locale]/login/page.tsx b/chatdesk-ui/app/[locale]/login/page.tsx new file mode 100644 index 0000000..ae1c779 --- /dev/null +++ b/chatdesk-ui/app/[locale]/login/page.tsx @@ -0,0 +1,260 @@ +import { Brand } from "@/components/ui/brand" +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { SubmitButton } from "@/components/ui/submit-button" +import { createClient } from "@/lib/supabase/server" +import { Database } from "@/supabase/types" +import { createServerClient } from "@supabase/ssr" +import { get } from "@vercel/edge-config" +import { Metadata } from "next" +import { cookies, headers } from "next/headers" +import { redirect } from "next/navigation" + + +import initTranslations from "@/lib/i18n"; + +export const metadata: Metadata = { + title: "Login" +} + +export default async function Login({ + searchParams, + params: { locale }, +}: { + searchParams: { message: string; email?: string }; + params: { locale: string }; +}) { + const cookieStore = cookies() + + const localeString = locale; + const { t, resources } = await initTranslations(localeString, ['translation']); + + const supabase = createServerClient( + process.env.NEXT_PUBLIC_SUPABASE_URL!, + process.env.NEXT_PUBLIC_SUPABASE_ANON_KEY!, + { + cookies: { + get(name: string) { + return cookieStore.get(name)?.value + } + } + } + ) + const session = (await supabase.auth.getSession()).data.session + + console.log("[login page]Login session:", session) + + if (session) { + const { data: homeWorkspace, error } = await supabase + .from("workspaces") + .select("*") + .eq("user_id", session.user.id) + .eq("is_home", true) + .single() + + if (!homeWorkspace) { + throw new Error(error.message) + } + + // console.log("[login page]======>Redirecting to workspace:", homeWorkspace.id) + return redirect(`/${localeString}/${homeWorkspace.id}/chat`) + + } + + const signIn = async (formData: FormData) => { + "use server" + + const email = formData.get("email") as string + const password = formData.get("password") as string + const cookieStore = cookies() + const supabase = createClient(cookieStore) + + const { data, error } = await supabase.auth.signInWithPassword({ + email, + password + }) + + if (error) { + // console.log(`[login page]==================> ${localeString}/login?message=${error.message}`); + // return redirect(`/${localeString}/login?message=${error.message}`) + return redirect(`/${localeString}/login?message=invalidCredentials`) + } + + const { data: homeWorkspace, error: homeWorkspaceError } = await supabase + .from("workspaces") + .select("*") + .eq("user_id", data.user.id) + .eq("is_home", true) + .single() + + if (!homeWorkspace) { + //const fallbackMessage = String(t("login.unexpectedError")) + throw new Error( + homeWorkspaceError?.message || "An unexpected error occurred" + // homeWorkspaceError?.message || t("login.unexpectedError") + ) + } + + return redirect(`/${localeString}/${homeWorkspace.id}/chat`) + } + + const getEnvVarOrEdgeConfigValue = async (name: string) => { + "use server" + if (process.env.EDGE_CONFIG) { + return await get(name) + } + + return process.env[name] + } + + const signUp = async (formData: FormData) => { + "use server" + + + const email = formData.get("email") as string + const password = formData.get("password") as string + + const emailDomainWhitelistPatternsString = await getEnvVarOrEdgeConfigValue( + "EMAIL_DOMAIN_WHITELIST" + ) + const emailDomainWhitelist = emailDomainWhitelistPatternsString?.trim() + ? 
emailDomainWhitelistPatternsString?.split(",") + : [] + const emailWhitelistPatternsString = + await getEnvVarOrEdgeConfigValue("EMAIL_WHITELIST") + const emailWhitelist = emailWhitelistPatternsString?.trim() + ? emailWhitelistPatternsString?.split(",") + : [] + + // If there are whitelist patterns, check if the email is allowed to sign up + if (emailDomainWhitelist.length > 0 || emailWhitelist.length > 0) { + const domainMatch = emailDomainWhitelist?.includes(email.split("@")[1]) + const emailMatch = emailWhitelist?.includes(email) + if (!domainMatch && !emailMatch) { + return redirect( + // `/${localeString}/login?message=Email ${email} is not allowed to sign up.` + `/${localeString}/login?message=signupNotAllowed&email=${encodeURIComponent(email)}` + ) + } + } + + const cookieStore = cookies() + const supabase = createClient(cookieStore) + + const { error } = await supabase.auth.signUp({ + email, + password, + options: { + // USE IF YOU WANT TO SEND EMAIL VERIFICATION, ALSO CHANGE TOML FILE + // emailRedirectTo: `${origin}/auth/callback` + } + }) + + if (error) { + console.error(error) + return redirect(`/${localeString}/login?message=${error.message}`) + } + + return redirect(`/${localeString}/setup`) + + // USE IF YOU WANT TO SEND EMAIL VERIFICATION, ALSO CHANGE TOML FILE + // return redirect("/login?message=Check email to continue sign in process") + } + + const handleResetPassword = async (formData: FormData) => { + "use server" + + const origin = headers().get("origin") + const email = formData.get("email") as string + const cookieStore = cookies() + const supabase = createClient(cookieStore) + + const { error } = await supabase.auth.resetPasswordForEmail(email, { + redirectTo: `${origin}/auth/callback?next=${localeString}/login/password` + }) + + if (error) { + return redirect(`/${localeString}/login?message=${error.message}`) + } + + // const emailtoResetMessage = String(t("login.checkEmailToReset")) // ← 这是字符串 + // return redirect(`/${localeString}/login?message=${emailtoResetMessage}`) + return redirect(`/${localeString}/login?message=Check email to reset password`) + } + + + let translatedMessage: string | null = null; + + if (searchParams.message === "signupNotAllowed") { + translatedMessage = t("login.signupNotAllowed", { email: searchParams.email }); + } else if (searchParams.message === "signupNotAllowed") { + + } else if (searchParams.message) { + translatedMessage = t(`login.${searchParams.message}`); + } + + return ( +
+
+ + + + + + + + + + {t("login.loginButton")} + + + + {t("login.signUpButton")} + + +
+ {t("login.forgotPassword")} + +
+ + {/* {searchParams?.message && ( +

+ {searchParams.message} +

+ )} */} + + {translatedMessage && ( +

+ {translatedMessage} +

+ )} + + +
+ ) +} diff --git a/chatdesk-ui/app/[locale]/login/password/page.tsx b/chatdesk-ui/app/[locale]/login/password/page.tsx new file mode 100644 index 0000000..ddee596 --- /dev/null +++ b/chatdesk-ui/app/[locale]/login/password/page.tsx @@ -0,0 +1,55 @@ +"use client" + +import { ChangePassword } from "@/components/utility/change-password" +import { supabase } from "@/lib/supabase/browser-client" +import { useRouter } from "next/navigation" +import { useEffect, useState } from "react" + +import { usePathname } from "next/navigation" // 导入 usePathname + +import i18nConfig from "@/i18nConfig" + +export default function ChangePasswordPage() { + const [loading, setLoading] = useState(true) + + const router = useRouter() + const pathname = usePathname() // 获取当前路径 + + useEffect(() => { + ;(async () => { + const session = (await supabase.auth.getSession()).data.session + + if (!session) { + // // 提取当前路径中的 locale 部分 + // const locale = pathname.split("/")[1] || "en" // 获取路径中的 locale 部分,如果没有则默认为 "en" + + const pathSegments = pathname.split("/").filter(Boolean) + const locales = i18nConfig.locales + const defaultLocale = i18nConfig.defaultLocale + + let locale: (typeof locales)[number] = defaultLocale + + const segment = pathSegments[0] as (typeof locales)[number] + + if (locales.includes(segment)) { + locale = segment + } + const homePath = locale === defaultLocale ? "/" : `/${locale}` + + + + console.log("...........[login page.tsx]") + router.push(`${homePath}/login`) + // router.push(`${locale}/login`) + } else { + setLoading(false) + } + })() + }, []) + + if (loading) { + return null + } + + return +} diff --git a/chatdesk-ui/app/[locale]/page.tsx b/chatdesk-ui/app/[locale]/page.tsx new file mode 100644 index 0000000..87ccf31 --- /dev/null +++ b/chatdesk-ui/app/[locale]/page.tsx @@ -0,0 +1,51 @@ +"use client" + +import { useEffect, useState } from "react" +import { ChatbotUISVG } from "@/components/icons/chatbotui-svg" +import { IconArrowRight } from "@tabler/icons-react" +import { useTheme } from "next-themes" +import Link from "next/link" + +import HomeRedirector from "@/components/utility/home-redirector" +import { LanguageSwitcher } from '@/components/ui/language-switcher' + +import { useTranslation } from "react-i18next" + +export default function HomePage() { + const { theme } = useTheme() + const { t, i18n } = useTranslation() + + const [preferredLanguage, setPreferredLanguage] = useState('en') // 默认语言为 'en' + + // 根据 localStorage 或 cookie 设置 preferredLanguage + useEffect(() => { + const languageFromStorage = localStorage.getItem('preferred-language') || document.cookie.split('; ').find(row => row.startsWith('preferred-language='))?.split('=')[1]; + if (languageFromStorage) { + setPreferredLanguage(languageFromStorage); + // 更新 i18n 的语言设置 + i18n.changeLanguage(languageFromStorage); // 通过 i18n 更新默认语言 + } + }, []); + + return ( +
+ + + + +
+ +
+ +
{t("Company Name")}
+ + + {t("Clock In")} + + +
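The effect above only reads the `preferred-language` value (from localStorage, falling back to the cookie); the write side — presumably in `LanguageSwitcher` — is not part of this hunk. A minimal sketch of a matching write helper, assuming the same key is used:

```ts
// Hypothetical helper — assumes the same "preferred-language" key that HomePage reads above.
import type { i18n as I18nInstance } from "i18next"

export const persistPreferredLanguage = (i18n: I18nInstance, lang: string) => {
  // Keep both stores in sync with what the read side checks.
  localStorage.setItem("preferred-language", lang)
  document.cookie = `preferred-language=${lang}; path=/; max-age=${60 * 60 * 24 * 365}`
  i18n.changeLanguage(lang)
}
```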
+ ) +} diff --git a/chatdesk-ui/app/[locale]/setup/page.tsx b/chatdesk-ui/app/[locale]/setup/page.tsx new file mode 100644 index 0000000..3b961e2 --- /dev/null +++ b/chatdesk-ui/app/[locale]/setup/page.tsx @@ -0,0 +1,292 @@ +'use client' + +import { ChatbotUIContext } from "@/context/context" +import { getProfileByUserId, updateProfile } from "@/db/profile" +import { + getHomeWorkspaceByUserId, + getWorkspacesByUserId +} from "@/db/workspaces" +import { + fetchHostedModels, + fetchOpenRouterModels +} from "@/lib/models/fetch-models" +import { supabase } from "@/lib/supabase/browser-client" +import { TablesUpdate } from "@/supabase/types" +import { useRouter } from "next/navigation" +import { useContext, useEffect, useState } from "react" +import { APIStep } from "../../../components/setup/api-step" +import { FinishStep } from "../../../components/setup/finish-step" +import { ProfileStep } from "../../../components/setup/profile-step" +import { + SETUP_STEP_COUNT, + StepContainer +} from "../../../components/setup/step-container" +import { useTranslation } from 'react-i18next' + +import { usePathname } from "next/navigation" +import i18nConfig from "@/i18nConfig" + + +export default function SetupPage() { + const { + profile, + setProfile, + setWorkspaces, + setSelectedWorkspace, + setEnvKeyMap, + setAvailableHostedModels, + setAvailableOpenRouterModels + } = useContext(ChatbotUIContext) + + const router = useRouter() + + + + const pathname = usePathname() // 获取当前路径 + const pathSegments = pathname.split("/").filter(Boolean) + const locales = i18nConfig.locales + const defaultLocale = i18nConfig.defaultLocale + + let locale: (typeof locales)[number] = defaultLocale + + const segment = pathSegments[0] as (typeof locales)[number] + + if (locales.includes(segment)) { + locale = segment + } + //const homePath = locale === defaultLocale ? "/" : `/${locale}` + const homePath = locale === defaultLocale ? 
"" : `/${locale}` + + + + + // // 提取当前路径中的 locale 部分 + // const locale = pathname.split("/")[1] || "en" // 获取路径中的 locale 部分,如果没有则默认为 "en" + + const { t } = useTranslation() + + const [loading, setLoading] = useState(true) + + const [currentStep, setCurrentStep] = useState(1) + + // Profile Step + const [displayName, setDisplayName] = useState("") + const [username, setUsername] = useState(profile?.username || "") + const [usernameAvailable, setUsernameAvailable] = useState(true) + + // API Step + const [useAzureOpenai, setUseAzureOpenai] = useState(false) + const [openaiAPIKey, setOpenaiAPIKey] = useState("") + const [openaiOrgID, setOpenaiOrgID] = useState("") + const [azureOpenaiAPIKey, setAzureOpenaiAPIKey] = useState("") + const [azureOpenaiEndpoint, setAzureOpenaiEndpoint] = useState("") + const [azureOpenai35TurboID, setAzureOpenai35TurboID] = useState("") + const [azureOpenai45TurboID, setAzureOpenai45TurboID] = useState("") + const [azureOpenai45VisionID, setAzureOpenai45VisionID] = useState("") + const [azureOpenaiEmbeddingsID, setAzureOpenaiEmbeddingsID] = useState("") + const [anthropicAPIKey, setAnthropicAPIKey] = useState("") + const [googleGeminiAPIKey, setGoogleGeminiAPIKey] = useState("") + const [mistralAPIKey, setMistralAPIKey] = useState("") + const [groqAPIKey, setGroqAPIKey] = useState("") + const [perplexityAPIKey, setPerplexityAPIKey] = useState("") + const [openrouterAPIKey, setOpenrouterAPIKey] = useState("") + + useEffect(() => { + ;(async () => { + const session = (await supabase.auth.getSession()).data.session + + if (!session) { + // 强制跳转到带有 locale 的 login 页面 + console.log("...........[setup/page.tsx]") + return router.push(`${homePath}/login`) + // return router.push(`/${locale}/login`) + } else { + const user = session.user + + const profile = await getProfileByUserId(user.id) + setProfile(profile) + setUsername(profile.username) + + if (!profile.has_onboarded) { + setLoading(false) + } else { + const data = await fetchHostedModels(profile) + + if (!data) return + + setEnvKeyMap(data.envKeyMap) + setAvailableHostedModels(data.hostedModels) + + if (profile["openrouter_api_key"] || data.envKeyMap["openrouter"]) { + const openRouterModels = await fetchOpenRouterModels() + if (!openRouterModels) return + setAvailableOpenRouterModels(openRouterModels) + } + + const homeWorkspaceId = await getHomeWorkspaceByUserId( + session.user.id + ) + return router.push(`${homePath}/${homeWorkspaceId}/chat`) + // return router.push(`/${locale}/${homeWorkspaceId}/chat`) + } + } + })() + }, []) + + const handleShouldProceed = (proceed: boolean) => { + if (proceed) { + if (currentStep === SETUP_STEP_COUNT) { + handleSaveSetupSetting() + } else { + setCurrentStep(currentStep + 1) + } + } else { + setCurrentStep(currentStep - 1) + } + } + + const handleSaveSetupSetting = async () => { + const session = (await supabase.auth.getSession()).data.session + if (!session) { + // return router.push(`/${locale}/login`) + return (`${homePath}/login`) + } + + const user = session.user + const profile = await getProfileByUserId(user.id) + + const updateProfilePayload: TablesUpdate<"profiles"> = { + ...profile, + has_onboarded: true, + display_name: displayName, + username, + openai_api_key: openaiAPIKey, + openai_organization_id: openaiOrgID, + anthropic_api_key: anthropicAPIKey, + google_gemini_api_key: googleGeminiAPIKey, + mistral_api_key: mistralAPIKey, + groq_api_key: groqAPIKey, + perplexity_api_key: perplexityAPIKey, + openrouter_api_key: openrouterAPIKey, + use_azure_openai: 
useAzureOpenai, + azure_openai_api_key: azureOpenaiAPIKey, + azure_openai_endpoint: azureOpenaiEndpoint, + azure_openai_35_turbo_id: azureOpenai35TurboID, + azure_openai_45_turbo_id: azureOpenai45TurboID, + azure_openai_45_vision_id: azureOpenai45VisionID, + azure_openai_embeddings_id: azureOpenaiEmbeddingsID + } + + const updatedProfile = await updateProfile(profile.id, updateProfilePayload) + setProfile(updatedProfile) + + const workspaces = await getWorkspacesByUserId(profile.user_id) + const homeWorkspace = workspaces.find(w => w.is_home) + + // There will always be a home workspace + setSelectedWorkspace(homeWorkspace!) + setWorkspaces(workspaces) + + return router.push(`${homePath}/${homeWorkspace?.id}/chat`) + // return router.push(`/${locale}/${homeWorkspace?.id}/chat`) + } + + const renderStep = (stepNum: number) => { + switch (stepNum) { + // Profile Step + case 1: + return ( + + + + ) + + // API Step + case 2: + return ( + + + + ) + + // Finish Step + case 3: + return ( + + + + ) + default: + return null + } + } + + if (loading) { + return null + } + + return ( +
+ {renderStep(currentStep)} +
+ ) +} diff --git a/chatdesk-ui/app/api/assistants/openai/route.ts b/chatdesk-ui/app/api/assistants/openai/route.ts new file mode 100644 index 0000000..7e16e52 --- /dev/null +++ b/chatdesk-ui/app/api/assistants/openai/route.ts @@ -0,0 +1,32 @@ +import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers" +import { ServerRuntime } from "next" +import OpenAI from "openai" + +export const runtime: ServerRuntime = "edge" + +export async function GET() { + try { + const profile = await getServerProfile() + + checkApiKey(profile.openai_api_key, "OpenAI") + + const openai = new OpenAI({ + apiKey: profile.openai_api_key || "", + organization: profile.openai_organization_id + }) + + const myAssistants = await openai.beta.assistants.list({ + limit: 100 + }) + + return new Response(JSON.stringify({ assistants: myAssistants.data }), { + status: 200 + }) + } catch (error: any) { + const errorMessage = error.error?.message || "An unexpected error occurred" + const errorCode = error.status || 500 + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/chat/anthropic/route.ts b/chatdesk-ui/app/api/chat/anthropic/route.ts new file mode 100644 index 0000000..4f6242f --- /dev/null +++ b/chatdesk-ui/app/api/chat/anthropic/route.ts @@ -0,0 +1,111 @@ +import { CHAT_SETTING_LIMITS } from "@/lib/chat-setting-limits" +import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers" +import { getBase64FromDataURL, getMediaTypeFromDataURL } from "@/lib/utils" +import { ChatSettings } from "@/types" +import Anthropic from "@anthropic-ai/sdk" +import { AnthropicStream, StreamingTextResponse } from "ai" +import { NextRequest, NextResponse } from "next/server" + +export const runtime = "edge" + +export async function POST(request: NextRequest) { + const json = await request.json() + const { chatSettings, messages } = json as { + chatSettings: ChatSettings + messages: any[] + } + + try { + const profile = await getServerProfile() + + checkApiKey(profile.anthropic_api_key, "Anthropic") + + let ANTHROPIC_FORMATTED_MESSAGES: any = messages.slice(1) + + ANTHROPIC_FORMATTED_MESSAGES = ANTHROPIC_FORMATTED_MESSAGES?.map( + (message: any) => { + const messageContent = + typeof message?.content === "string" + ? 
[message.content] + : message?.content + + return { + ...message, + content: messageContent.map((content: any) => { + if (typeof content === "string") { + // Handle the case where content is a string + return { type: "text", text: content } + } else if ( + content?.type === "image_url" && + content?.image_url?.url?.length + ) { + return { + type: "image", + source: { + type: "base64", + media_type: getMediaTypeFromDataURL(content.image_url.url), + data: getBase64FromDataURL(content.image_url.url) + } + } + } else { + return content + } + }) + } + } + ) + + const anthropic = new Anthropic({ + apiKey: profile.anthropic_api_key || "" + }) + + try { + const response = await anthropic.messages.create({ + model: chatSettings.model, + messages: ANTHROPIC_FORMATTED_MESSAGES, + temperature: chatSettings.temperature, + system: messages[0].content, + max_tokens: + CHAT_SETTING_LIMITS[chatSettings.model].MAX_TOKEN_OUTPUT_LENGTH, + stream: true + }) + + try { + const stream = AnthropicStream(response) + return new StreamingTextResponse(stream) + } catch (error: any) { + console.error("Error parsing Anthropic API response:", error) + return new NextResponse( + JSON.stringify({ + message: + "An error occurred while parsing the Anthropic API response" + }), + { status: 500 } + ) + } + } catch (error: any) { + console.error("Error calling Anthropic API:", error) + return new NextResponse( + JSON.stringify({ + message: "An error occurred while calling the Anthropic API" + }), + { status: 500 } + ) + } + } catch (error: any) { + let errorMessage = error.message || "An unexpected error occurred" + const errorCode = error.status || 500 + + if (errorMessage.toLowerCase().includes("api key not found")) { + errorMessage = + "Anthropic API Key not found. Please set it in your profile settings." + } else if (errorCode === 401) { + errorMessage = + "Anthropic API Key is incorrect. Please fix it in your profile settings." 
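The mapping above turns OpenAI-style `image_url` content parts into Anthropic base64 image blocks via `getMediaTypeFromDataURL` and `getBase64FromDataURL` from `@/lib/utils` (not shown in this hunk). Assuming the URLs are standard data URLs, those helpers are presumably along these lines:

```ts
// Assumed implementations — @/lib/utils is not part of this hunk.
// A data URL looks like: "data:image/png;base64,iVBORw0KGgo..."
export const getMediaTypeFromDataURL = (dataURL: string): string | null => {
  const match = dataURL.match(/^data:([^;]+);base64,/)
  return match ? match[1] : null
}

export const getBase64FromDataURL = (dataURL: string): string | null => {
  const match = dataURL.match(/^data:[^;]+;base64,(.+)$/)
  return match ? match[1] : null
}
```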
+ } + + return new NextResponse(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/chat/azure/route.ts b/chatdesk-ui/app/api/chat/azure/route.ts new file mode 100644 index 0000000..642eb77 --- /dev/null +++ b/chatdesk-ui/app/api/chat/azure/route.ts @@ -0,0 +1,72 @@ +import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers" +import { ChatAPIPayload } from "@/types" +import { OpenAIStream, StreamingTextResponse } from "ai" +import OpenAI from "openai" +import { ChatCompletionCreateParamsBase } from "openai/resources/chat/completions.mjs" + +export const runtime = "edge" + +export async function POST(request: Request) { + const json = await request.json() + const { chatSettings, messages } = json as ChatAPIPayload + + try { + const profile = await getServerProfile() + + checkApiKey(profile.azure_openai_api_key, "Azure OpenAI") + + const ENDPOINT = profile.azure_openai_endpoint + const KEY = profile.azure_openai_api_key + + let DEPLOYMENT_ID = "" + switch (chatSettings.model) { + case "gpt-3.5-turbo": + DEPLOYMENT_ID = profile.azure_openai_35_turbo_id || "" + break + case "gpt-4-turbo-preview": + DEPLOYMENT_ID = profile.azure_openai_45_turbo_id || "" + break + case "gpt-4-vision-preview": + DEPLOYMENT_ID = profile.azure_openai_45_vision_id || "" + break + default: + return new Response(JSON.stringify({ message: "Model not found" }), { + status: 400 + }) + } + + if (!ENDPOINT || !KEY || !DEPLOYMENT_ID) { + return new Response( + JSON.stringify({ message: "Azure resources not found" }), + { + status: 400 + } + ) + } + + const azureOpenai = new OpenAI({ + apiKey: KEY, + baseURL: `${ENDPOINT}/openai/deployments/${DEPLOYMENT_ID}`, + defaultQuery: { "api-version": "2023-12-01-preview" }, + defaultHeaders: { "api-key": KEY } + }) + + const response = await azureOpenai.chat.completions.create({ + model: DEPLOYMENT_ID as ChatCompletionCreateParamsBase["model"], + messages: messages as ChatCompletionCreateParamsBase["messages"], + temperature: chatSettings.temperature, + max_tokens: chatSettings.model === "gpt-4-vision-preview" ? 4096 : null, // TODO: Fix + stream: true + }) + + const stream = OpenAIStream(response) + + return new StreamingTextResponse(stream) + } catch (error: any) { + const errorMessage = error.error?.message || "An unexpected error occurred" + const errorCode = error.status || 500 + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/chat/custom/route.ts b/chatdesk-ui/app/api/chat/custom/route.ts new file mode 100644 index 0000000..2c8e7c8 --- /dev/null +++ b/chatdesk-ui/app/api/chat/custom/route.ts @@ -0,0 +1,66 @@ +import { Database } from "@/supabase/types" +import { ChatSettings } from "@/types" +import { createClient } from "@supabase/supabase-js" +import { OpenAIStream, StreamingTextResponse } from "ai" +import { ServerRuntime } from "next" +import OpenAI from "openai" +import { ChatCompletionCreateParamsBase } from "openai/resources/chat/completions.mjs" + +export const runtime: ServerRuntime = "edge" + +export async function POST(request: Request) { + const json = await request.json() + const { chatSettings, messages, customModelId } = json as { + chatSettings: ChatSettings + messages: any[] + customModelId: string + } + + try { + const supabaseAdmin = createClient( + process.env.NEXT_PUBLIC_SUPABASE_URL!, + process.env.SUPABASE_SERVICE_ROLE_KEY! 
+ ) + + const { data: customModel, error } = await supabaseAdmin + .from("models") + .select("*") + .eq("id", customModelId) + .single() + + if (!customModel) { + throw new Error(error.message) + } + + const custom = new OpenAI({ + apiKey: customModel.api_key || "", + baseURL: customModel.base_url + }) + + const response = await custom.chat.completions.create({ + model: chatSettings.model as ChatCompletionCreateParamsBase["model"], + messages: messages as ChatCompletionCreateParamsBase["messages"], + temperature: chatSettings.temperature, + stream: true + }) + + const stream = OpenAIStream(response) + + return new StreamingTextResponse(stream) + } catch (error: any) { + let errorMessage = error.message || "An unexpected error occurred" + const errorCode = error.status || 500 + + if (errorMessage.toLowerCase().includes("api key not found")) { + errorMessage = + "Custom API Key not found. Please set it in your profile settings." + } else if (errorMessage.toLowerCase().includes("incorrect api key")) { + errorMessage = + "Custom API Key is incorrect. Please fix it in your profile settings." + } + + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/chat/google/route.ts b/chatdesk-ui/app/api/chat/google/route.ts new file mode 100644 index 0000000..ad79139 --- /dev/null +++ b/chatdesk-ui/app/api/chat/google/route.ts @@ -0,0 +1,64 @@ +import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers" +import { ChatSettings } from "@/types" +import { GoogleGenerativeAI } from "@google/generative-ai" + +export const runtime = "edge" + +export async function POST(request: Request) { + const json = await request.json() + const { chatSettings, messages } = json as { + chatSettings: ChatSettings + messages: any[] + } + + try { + const profile = await getServerProfile() + + checkApiKey(profile.google_gemini_api_key, "Google") + + const genAI = new GoogleGenerativeAI(profile.google_gemini_api_key || "") + const googleModel = genAI.getGenerativeModel({ model: chatSettings.model }) + + const lastMessage = messages.pop() + + const chat = googleModel.startChat({ + history: messages, + generationConfig: { + temperature: chatSettings.temperature + } + }) + + const response = await chat.sendMessageStream(lastMessage.parts) + + const encoder = new TextEncoder() + const readableStream = new ReadableStream({ + async start(controller) { + for await (const chunk of response.stream) { + const chunkText = chunk.text() + controller.enqueue(encoder.encode(chunkText)) + } + controller.close() + } + }) + + return new Response(readableStream, { + headers: { "Content-Type": "text/plain" } + }) + + } catch (error: any) { + let errorMessage = error.message || "An unexpected error occurred" + const errorCode = error.status || 500 + + if (errorMessage.toLowerCase().includes("api key not found")) { + errorMessage = + "Google Gemini API Key not found. Please set it in your profile settings." + } else if (errorMessage.toLowerCase().includes("api key not valid")) { + errorMessage = + "Google Gemini API Key is incorrect. Please fix it in your profile settings." 
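The Google route above returns a raw text stream (`Content-Type: text/plain`), and the other chat routes return a `StreamingTextResponse`, which is also plain text, so a client reads the body incrementally instead of calling `response.json()`. A minimal consumption sketch; the endpoint path and the `{ chatSettings, messages }` body shape are taken from the routes in this diff:

```ts
// Minimal client-side consumer for the streaming chat routes above.
async function streamChat(chatSettings: unknown, messages: unknown[]): Promise<string> {
  const response = await fetch("/api/chat/google", {
    method: "POST",
    body: JSON.stringify({ chatSettings, messages })
  })

  if (!response.ok || !response.body) {
    // Error responses from these routes are JSON: { message: string }
    throw new Error(`Chat request failed with status ${response.status}`)
  }

  const reader = response.body.getReader()
  const decoder = new TextDecoder()
  let fullText = ""

  while (true) {
    const { value, done } = await reader.read()
    if (done) break
    fullText += decoder.decode(value, { stream: true }) // each chunk is generated text
  }

  return fullText
}
```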
+ } + + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/chat/groq/route.ts b/chatdesk-ui/app/api/chat/groq/route.ts new file mode 100644 index 0000000..653de00 --- /dev/null +++ b/chatdesk-ui/app/api/chat/groq/route.ts @@ -0,0 +1,55 @@ +import { CHAT_SETTING_LIMITS } from "@/lib/chat-setting-limits" +import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers" +import { ChatSettings } from "@/types" +import { OpenAIStream, StreamingTextResponse } from "ai" +import OpenAI from "openai" + +export const runtime = "edge" +export async function POST(request: Request) { + const json = await request.json() + const { chatSettings, messages } = json as { + chatSettings: ChatSettings + messages: any[] + } + + try { + const profile = await getServerProfile() + + checkApiKey(profile.groq_api_key, "G") + + // Groq is compatible with the OpenAI SDK + const groq = new OpenAI({ + apiKey: profile.groq_api_key || "", + baseURL: "https://api.groq.com/openai/v1" + }) + + const response = await groq.chat.completions.create({ + model: chatSettings.model, + messages, + max_tokens: + CHAT_SETTING_LIMITS[chatSettings.model].MAX_TOKEN_OUTPUT_LENGTH, + stream: true + }) + + // Convert the response into a friendly text-stream. + const stream = OpenAIStream(response) + + // Respond with the stream + return new StreamingTextResponse(stream) + } catch (error: any) { + let errorMessage = error.message || "An unexpected error occurred" + const errorCode = error.status || 500 + + if (errorMessage.toLowerCase().includes("api key not found")) { + errorMessage = + "Groq API Key not found. Please set it in your profile settings." + } else if (errorCode === 401) { + errorMessage = + "Groq API Key is incorrect. Please fix it in your profile settings." + } + + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/chat/mistral/route.ts b/chatdesk-ui/app/api/chat/mistral/route.ts new file mode 100644 index 0000000..5153ca6 --- /dev/null +++ b/chatdesk-ui/app/api/chat/mistral/route.ts @@ -0,0 +1,56 @@ +import { CHAT_SETTING_LIMITS } from "@/lib/chat-setting-limits" +import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers" +import { ChatSettings } from "@/types" +import { OpenAIStream, StreamingTextResponse } from "ai" +import OpenAI from "openai" + +export const runtime = "edge" + +export async function POST(request: Request) { + const json = await request.json() + const { chatSettings, messages } = json as { + chatSettings: ChatSettings + messages: any[] + } + + try { + const profile = await getServerProfile() + + checkApiKey(profile.mistral_api_key, "Mistral") + + // Mistral is compatible the OpenAI SDK + const mistral = new OpenAI({ + apiKey: profile.mistral_api_key || "", + baseURL: "https://api.mistral.ai/v1" + }) + + const response = await mistral.chat.completions.create({ + model: chatSettings.model, + messages, + max_tokens: + CHAT_SETTING_LIMITS[chatSettings.model].MAX_TOKEN_OUTPUT_LENGTH, + stream: true + }) + + // Convert the response into a friendly text-stream. + const stream = OpenAIStream(response) + + // Respond with the stream + return new StreamingTextResponse(stream) + } catch (error: any) { + let errorMessage = error.message || "An unexpected error occurred" + const errorCode = error.status || 500 + + if (errorMessage.toLowerCase().includes("api key not found")) { + errorMessage = + "Mistral API Key not found. 
Please set it in your profile settings." + } else if (errorCode === 401) { + errorMessage = + "Mistral API Key is incorrect. Please fix it in your profile settings." + } + + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/chat/openai/route.ts b/chatdesk-ui/app/api/chat/openai/route.ts new file mode 100644 index 0000000..a0f8ad0 --- /dev/null +++ b/chatdesk-ui/app/api/chat/openai/route.ts @@ -0,0 +1,58 @@ +import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers" +import { ChatSettings } from "@/types" +import { OpenAIStream, StreamingTextResponse } from "ai" +import { ServerRuntime } from "next" +import OpenAI from "openai" +import { ChatCompletionCreateParamsBase } from "openai/resources/chat/completions.mjs" + +export const runtime: ServerRuntime = "edge" + +export async function POST(request: Request) { + const json = await request.json() + const { chatSettings, messages } = json as { + chatSettings: ChatSettings + messages: any[] + } + + try { + const profile = await getServerProfile() + + checkApiKey(profile.openai_api_key, "OpenAI") + + const openai = new OpenAI({ + apiKey: profile.openai_api_key || "", + organization: profile.openai_organization_id + }) + + const response = await openai.chat.completions.create({ + model: chatSettings.model as ChatCompletionCreateParamsBase["model"], + messages: messages as ChatCompletionCreateParamsBase["messages"], + temperature: chatSettings.temperature, + max_tokens: + chatSettings.model === "gpt-4-vision-preview" || + chatSettings.model === "gpt-4o" + ? 4096 + : null, // TODO: Fix + stream: true + }) + + const stream = OpenAIStream(response) + + return new StreamingTextResponse(stream) + } catch (error: any) { + let errorMessage = error.message || "An unexpected error occurred" + const errorCode = error.status || 500 + + if (errorMessage.toLowerCase().includes("api key not found")) { + errorMessage = + "OpenAI API Key not found. Please set it in your profile settings." + } else if (errorMessage.toLowerCase().includes("incorrect api key")) { + errorMessage = + "OpenAI API Key is incorrect. Please fix it in your profile settings." 
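Several chat routes above read `CHAT_SETTING_LIMITS[chatSettings.model].MAX_TOKEN_OUTPUT_LENGTH`. That table lives in `@/lib/chat-setting-limits`, which is not part of this hunk; it is presumably a per-model record shaped roughly like this (field names follow the usage above, the numbers are only illustrative):

```ts
// Assumed shape of @/lib/chat-setting-limits — illustrative entries, not the project's table.
import { LLMID } from "@/types"

interface ChatSettingLimits {
  MIN_TEMPERATURE: number
  MAX_TEMPERATURE: number
  MAX_TOKEN_OUTPUT_LENGTH: number
  MAX_CONTEXT_LENGTH: number
}

export const CHAT_SETTING_LIMITS = {
  "gpt-3.5-turbo": {
    MIN_TEMPERATURE: 0,
    MAX_TEMPERATURE: 2,
    MAX_TOKEN_OUTPUT_LENGTH: 4096,
    MAX_CONTEXT_LENGTH: 16385
  },
  "gpt-4-turbo-preview": {
    MIN_TEMPERATURE: 0,
    MAX_TEMPERATURE: 2,
    MAX_TOKEN_OUTPUT_LENGTH: 4096,
    MAX_CONTEXT_LENGTH: 128000
  }
} as Record<LLMID, ChatSettingLimits>
```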
+ } + + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/chat/openrouter/route.ts b/chatdesk-ui/app/api/chat/openrouter/route.ts new file mode 100644 index 0000000..34a8a74 --- /dev/null +++ b/chatdesk-ui/app/api/chat/openrouter/route.ts @@ -0,0 +1,51 @@ +import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers" +import { ChatSettings } from "@/types" +import { OpenAIStream, StreamingTextResponse } from "ai" +import { ServerRuntime } from "next" +import OpenAI from "openai" +import { ChatCompletionCreateParamsBase } from "openai/resources/chat/completions.mjs" + +export const runtime: ServerRuntime = "edge" + +export async function POST(request: Request) { + const json = await request.json() + const { chatSettings, messages } = json as { + chatSettings: ChatSettings + messages: any[] + } + + try { + const profile = await getServerProfile() + + checkApiKey(profile.openrouter_api_key, "OpenRouter") + + const openai = new OpenAI({ + apiKey: profile.openrouter_api_key || "", + baseURL: "https://openrouter.ai/api/v1" + }) + + const response = await openai.chat.completions.create({ + model: chatSettings.model as ChatCompletionCreateParamsBase["model"], + messages: messages as ChatCompletionCreateParamsBase["messages"], + temperature: chatSettings.temperature, + max_tokens: undefined, + stream: true + }) + + const stream = OpenAIStream(response) + + return new StreamingTextResponse(stream) + } catch (error: any) { + let errorMessage = error.message || "An unexpected error occurred" + const errorCode = error.status || 500 + + if (errorMessage.toLowerCase().includes("api key not found")) { + errorMessage = + "OpenRouter API Key not found. Please set it in your profile settings." + } + + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/chat/perplexity/route.ts b/chatdesk-ui/app/api/chat/perplexity/route.ts new file mode 100644 index 0000000..db700a2 --- /dev/null +++ b/chatdesk-ui/app/api/chat/perplexity/route.ts @@ -0,0 +1,51 @@ +import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers" +import { ChatSettings } from "@/types" +import { OpenAIStream, StreamingTextResponse } from "ai" +import OpenAI from "openai" + +export const runtime = "edge" + +export async function POST(request: Request) { + const json = await request.json() + const { chatSettings, messages } = json as { + chatSettings: ChatSettings + messages: any[] + } + + try { + const profile = await getServerProfile() + + checkApiKey(profile.perplexity_api_key, "Perplexity") + + // Perplexity is compatible the OpenAI SDK + const perplexity = new OpenAI({ + apiKey: profile.perplexity_api_key || "", + baseURL: "https://api.perplexity.ai/" + }) + + const response = await perplexity.chat.completions.create({ + model: chatSettings.model, + messages, + stream: true + }) + + const stream = OpenAIStream(response) + + return new StreamingTextResponse(stream) + } catch (error: any) { + let errorMessage = error.message || "An unexpected error occurred" + const errorCode = error.status || 500 + + if (errorMessage.toLowerCase().includes("api key not found")) { + errorMessage = + "Perplexity API Key not found. Please set it in your profile settings." + } else if (errorCode === 401) { + errorMessage = + "Perplexity API Key is incorrect. Please fix it in your profile settings." 
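Each provider route above ends with the same catch-block mapping: "api key not found" becomes a profile-settings hint, and a 401 (or "incorrect api key") becomes an "incorrect key" hint. A small shared helper is one way to express that repetition — a refactoring sketch, not code from this commit:

```ts
// Refactoring sketch that mirrors the per-route catch blocks above.
export function providerErrorResponse(error: any, provider: string): Response {
  let message: string = error.message || "An unexpected error occurred"
  const status: number = error.status || 500
  const lower = message.toLowerCase()

  if (lower.includes("api key not found")) {
    message = `${provider} API Key not found. Please set it in your profile settings.`
  } else if (status === 401 || lower.includes("incorrect api key")) {
    message = `${provider} API Key is incorrect. Please fix it in your profile settings.`
  }

  return new Response(JSON.stringify({ message }), { status })
}

// A route's catch block could then collapse to:
//   catch (error: any) { return providerErrorResponse(error, "Mistral") }
```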
+ } + + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/chat/tools/route.ts b/chatdesk-ui/app/api/chat/tools/route.ts new file mode 100644 index 0000000..752df25 --- /dev/null +++ b/chatdesk-ui/app/api/chat/tools/route.ts @@ -0,0 +1,218 @@ +import { openapiToFunctions } from "@/lib/openapi-conversion" +import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers" +import { Tables } from "@/supabase/types" +import { ChatSettings } from "@/types" +import { OpenAIStream, StreamingTextResponse } from "ai" +import OpenAI from "openai" +import { ChatCompletionCreateParamsBase } from "openai/resources/chat/completions.mjs" + +export async function POST(request: Request) { + const json = await request.json() + const { chatSettings, messages, selectedTools } = json as { + chatSettings: ChatSettings + messages: any[] + selectedTools: Tables<"tools">[] + } + + try { + const profile = await getServerProfile() + + checkApiKey(profile.openai_api_key, "OpenAI") + + const openai = new OpenAI({ + apiKey: profile.openai_api_key || "", + organization: profile.openai_organization_id + }) + + let allTools: OpenAI.Chat.Completions.ChatCompletionTool[] = [] + let allRouteMaps = {} + let schemaDetails = [] + + for (const selectedTool of selectedTools) { + try { + const convertedSchema = await openapiToFunctions( + JSON.parse(selectedTool.schema as string) + ) + const tools = convertedSchema.functions || [] + allTools = allTools.concat(tools) + + const routeMap = convertedSchema.routes.reduce( + (map: Record, route) => { + map[route.path.replace(/{(\w+)}/g, ":$1")] = route.operationId + return map + }, + {} + ) + + allRouteMaps = { ...allRouteMaps, ...routeMap } + + schemaDetails.push({ + title: convertedSchema.info.title, + description: convertedSchema.info.description, + url: convertedSchema.info.server, + headers: selectedTool.custom_headers, + routeMap, + requestInBody: convertedSchema.routes[0].requestInBody + }) + } catch (error: any) { + console.error("Error converting schema", error) + } + } + + const firstResponse = await openai.chat.completions.create({ + model: chatSettings.model as ChatCompletionCreateParamsBase["model"], + messages, + tools: allTools.length > 0 ? 
allTools : undefined + }) + + const message = firstResponse.choices[0].message + messages.push(message) + const toolCalls = message.tool_calls || [] + + if (toolCalls.length === 0) { + return new Response(message.content, { + headers: { + "Content-Type": "application/json" + } + }) + } + + if (toolCalls.length > 0) { + for (const toolCall of toolCalls) { + const functionCall = toolCall.function + const functionName = functionCall.name + const argumentsString = toolCall.function.arguments.trim() + const parsedArgs = JSON.parse(argumentsString) + + // Find the schema detail that contains the function name + const schemaDetail = schemaDetails.find(detail => + Object.values(detail.routeMap).includes(functionName) + ) + + if (!schemaDetail) { + throw new Error(`Function ${functionName} not found in any schema`) + } + + const pathTemplate = Object.keys(schemaDetail.routeMap).find( + key => schemaDetail.routeMap[key] === functionName + ) + + if (!pathTemplate) { + throw new Error(`Path for function ${functionName} not found`) + } + + const path = pathTemplate.replace(/:(\w+)/g, (_, paramName) => { + const value = parsedArgs.parameters[paramName] + if (!value) { + throw new Error( + `Parameter ${paramName} not found for function ${functionName}` + ) + } + return encodeURIComponent(value) + }) + + if (!path) { + throw new Error(`Path for function ${functionName} not found`) + } + + // Determine if the request should be in the body or as a query + const isRequestInBody = schemaDetail.requestInBody + let data = {} + + if (isRequestInBody) { + // If the type is set to body + let headers = { + "Content-Type": "application/json" + } + + // Check if custom headers are set + const customHeaders = schemaDetail.headers // Moved this line up to the loop + // Check if custom headers are set and are of type string + if (customHeaders && typeof customHeaders === "string") { + let parsedCustomHeaders = JSON.parse(customHeaders) as Record< + string, + string + > + + headers = { + ...headers, + ...parsedCustomHeaders + } + } + + const fullUrl = schemaDetail.url + path + + const bodyContent = parsedArgs.requestBody || parsedArgs + + const requestInit = { + method: "POST", + headers, + body: JSON.stringify(bodyContent) // Use the extracted requestBody or the entire parsedArgs + } + + const response = await fetch(fullUrl, requestInit) + + if (!response.ok) { + data = { + error: response.statusText + } + } else { + data = await response.json() + } + } else { + // If the type is set to query + const queryParams = new URLSearchParams( + parsedArgs.parameters + ).toString() + const fullUrl = + schemaDetail.url + path + (queryParams ? "?" 
+ queryParams : "") + + let headers = {} + + // Check if custom headers are set + const customHeaders = schemaDetail.headers + if (customHeaders && typeof customHeaders === "string") { + headers = JSON.parse(customHeaders) + } + + const response = await fetch(fullUrl, { + method: "GET", + headers: headers + }) + + if (!response.ok) { + data = { + error: response.statusText + } + } else { + data = await response.json() + } + } + + messages.push({ + tool_call_id: toolCall.id, + role: "tool", + name: functionName, + content: JSON.stringify(data) + }) + } + } + + const secondResponse = await openai.chat.completions.create({ + model: chatSettings.model as ChatCompletionCreateParamsBase["model"], + messages, + stream: true + }) + + const stream = OpenAIStream(secondResponse) + + return new StreamingTextResponse(stream) + } catch (error: any) { + console.error(error) + const errorMessage = error.error?.message || "An unexpected error occurred" + const errorCode = error.status || 500 + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/command/route.ts b/chatdesk-ui/app/api/command/route.ts new file mode 100644 index 0000000..b10df7b --- /dev/null +++ b/chatdesk-ui/app/api/command/route.ts @@ -0,0 +1,54 @@ +import { CHAT_SETTING_LIMITS } from "@/lib/chat-setting-limits" +import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers" +import OpenAI from "openai" + +export const runtime = "edge" + +export async function POST(request: Request) { + const json = await request.json() + const { input } = json as { + input: string + } + + try { + const profile = await getServerProfile() + + checkApiKey(profile.openai_api_key, "OpenAI") + + const openai = new OpenAI({ + apiKey: profile.openai_api_key || "", + organization: profile.openai_organization_id + }) + + const response = await openai.chat.completions.create({ + model: "gpt-4-1106-preview", + messages: [ + { + role: "system", + content: "Respond to the user." 
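To make the tool-calling flow above concrete: the OpenAPI path `/pets/{petId}` is first rewritten into the `routeMap` key `/pets/:petId`, and the `:param` segments are then filled from the model's parsed arguments. A self-contained example of that substitution (hypothetical route and arguments, same regexes as the route above):

```ts
// Hypothetical example of the routeMap + path substitution used in the tools route above.
const openApiPath = "/pets/{petId}"
const operationId = "getPetById"

// Build the routeMap key the same way the route does: "{petId}" -> ":petId"
const routeMap: Record<string, string> = {
  [openApiPath.replace(/{(\w+)}/g, ":$1")]: operationId
}

// What the model's tool call would hand back after JSON.parse(arguments)
const parsedArgs = { parameters: { petId: "42" } as Record<string, string> }

const pathTemplate = Object.keys(routeMap).find(
  key => routeMap[key] === operationId
)! // "/pets/:petId"

const path = pathTemplate.replace(/:(\w+)/g, (_, paramName: string) => {
  const value = parsedArgs.parameters[paramName]
  if (!value) throw new Error(`Parameter ${paramName} not found`)
  return encodeURIComponent(value)
})

console.log(path) // "/pets/42"
```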
+ }, + { + role: "user", + content: input + } + ], + temperature: 0, + max_tokens: + CHAT_SETTING_LIMITS["gpt-4-turbo-preview"].MAX_TOKEN_OUTPUT_LENGTH + // response_format: { type: "json_object" } + // stream: true + }) + + const content = response.choices[0].message.content + + return new Response(JSON.stringify({ content }), { + status: 200 + }) + } catch (error: any) { + const errorMessage = error.error?.message || "An unexpected error occurred" + const errorCode = error.status || 500 + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/keys/route.ts b/chatdesk-ui/app/api/keys/route.ts new file mode 100644 index 0000000..c0f8e38 --- /dev/null +++ b/chatdesk-ui/app/api/keys/route.ts @@ -0,0 +1,38 @@ +import { isUsingEnvironmentKey } from "@/lib/envs" +import { createResponse } from "@/lib/server/server-utils" +import { EnvKey } from "@/types/key-type" +import { VALID_ENV_KEYS } from "@/types/valid-keys" + +export async function GET() { + const envKeyMap: Record = { + azure: VALID_ENV_KEYS.AZURE_OPENAI_API_KEY, + openai: VALID_ENV_KEYS.OPENAI_API_KEY, + google: VALID_ENV_KEYS.GOOGLE_GEMINI_API_KEY, + anthropic: VALID_ENV_KEYS.ANTHROPIC_API_KEY, + mistral: VALID_ENV_KEYS.MISTRAL_API_KEY, + groq: VALID_ENV_KEYS.GROQ_API_KEY, + perplexity: VALID_ENV_KEYS.PERPLEXITY_API_KEY, + openrouter: VALID_ENV_KEYS.OPENROUTER_API_KEY, + + openai_organization_id: VALID_ENV_KEYS.OPENAI_ORGANIZATION_ID, + + azure_openai_endpoint: VALID_ENV_KEYS.AZURE_OPENAI_ENDPOINT, + azure_gpt_35_turbo_name: VALID_ENV_KEYS.AZURE_GPT_35_TURBO_NAME, + azure_gpt_45_vision_name: VALID_ENV_KEYS.AZURE_GPT_45_VISION_NAME, + azure_gpt_45_turbo_name: VALID_ENV_KEYS.AZURE_GPT_45_TURBO_NAME, + azure_embeddings_name: VALID_ENV_KEYS.AZURE_EMBEDDINGS_NAME + } + + const isUsingEnvKeyMap = Object.keys(envKeyMap).reduce< + Record + >((acc, provider) => { + const key = envKeyMap[provider] + + if (key) { + acc[provider] = isUsingEnvironmentKey(key as EnvKey) + } + return acc + }, {}) + + return createResponse({ isUsingEnvKeyMap }, 200) +} diff --git a/chatdesk-ui/app/api/retrieval/process/docx/route.ts b/chatdesk-ui/app/api/retrieval/process/docx/route.ts new file mode 100644 index 0000000..cea3d7a --- /dev/null +++ b/chatdesk-ui/app/api/retrieval/process/docx/route.ts @@ -0,0 +1,121 @@ +import { generateLocalEmbedding } from "@/lib/generate-local-embedding" +import { processDocX } from "@/lib/retrieval/processing" +import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers" +import { Database } from "@/supabase/types" +import { FileItemChunk } from "@/types" +import { createClient } from "@supabase/supabase-js" +import { NextResponse } from "next/server" +import OpenAI from "openai" + +export async function POST(req: Request) { + const json = await req.json() + const { text, fileId, embeddingsProvider, fileExtension } = json as { + text: string + fileId: string + embeddingsProvider: "openai" | "local" + fileExtension: string + } + + try { + const supabaseAdmin = createClient( + process.env.NEXT_PUBLIC_SUPABASE_URL!, + process.env.SUPABASE_SERVICE_ROLE_KEY! 
+ ) + + const profile = await getServerProfile() + + if (embeddingsProvider === "openai") { + if (profile.use_azure_openai) { + checkApiKey(profile.azure_openai_api_key, "Azure OpenAI") + } else { + checkApiKey(profile.openai_api_key, "OpenAI") + } + } + + let chunks: FileItemChunk[] = [] + + switch (fileExtension) { + case "docx": + chunks = await processDocX(text) + break + default: + return new NextResponse("Unsupported file type", { + status: 400 + }) + } + + let embeddings: any = [] + + let openai + if (profile.use_azure_openai) { + openai = new OpenAI({ + apiKey: profile.azure_openai_api_key || "", + baseURL: `${profile.azure_openai_endpoint}/openai/deployments/${profile.azure_openai_embeddings_id}`, + defaultQuery: { "api-version": "2023-12-01-preview" }, + defaultHeaders: { "api-key": profile.azure_openai_api_key } + }) + } else { + openai = new OpenAI({ + apiKey: profile.openai_api_key || "", + organization: profile.openai_organization_id + }) + } + + if (embeddingsProvider === "openai") { + const response = await openai.embeddings.create({ + model: "text-embedding-3-small", + input: chunks.map(chunk => chunk.content) + }) + + embeddings = response.data.map((item: any) => { + return item.embedding + }) + } else if (embeddingsProvider === "local") { + const embeddingPromises = chunks.map(async chunk => { + try { + return await generateLocalEmbedding(chunk.content) + } catch (error) { + console.error(`Error generating embedding for chunk: ${chunk}`, error) + return null + } + }) + + embeddings = await Promise.all(embeddingPromises) + } + + const file_items = chunks.map((chunk, index) => ({ + file_id: fileId, + user_id: profile.user_id, + content: chunk.content, + tokens: chunk.tokens, + openai_embedding: + embeddingsProvider === "openai" + ? ((embeddings[index] || null) as any) + : null, + local_embedding: + embeddingsProvider === "local" + ? ((embeddings[index] || null) as any) + : null + })) + + await supabaseAdmin.from("file_items").upsert(file_items) + + const totalTokens = file_items.reduce((acc, item) => acc + item.tokens, 0) + + await supabaseAdmin + .from("files") + .update({ tokens: totalTokens }) + .eq("id", fileId) + + return new NextResponse("Embed Successful", { + status: 200 + }) + } catch (error: any) { + console.error(error) + const errorMessage = error.error?.message || "An unexpected error occurred" + const errorCode = error.status || 500 + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/retrieval/process/route.ts b/chatdesk-ui/app/api/retrieval/process/route.ts new file mode 100644 index 0000000..f0221aa --- /dev/null +++ b/chatdesk-ui/app/api/retrieval/process/route.ts @@ -0,0 +1,175 @@ +import { generateLocalEmbedding } from "@/lib/generate-local-embedding" +import { + processCSV, + processJSON, + processMarkdown, + processPdf, + processTxt +} from "@/lib/retrieval/processing" +import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers" +import { Database } from "@/supabase/types" +import { FileItemChunk } from "@/types" +import { createClient } from "@supabase/supabase-js" +import { NextResponse } from "next/server" +import OpenAI from "openai" + +export async function POST(req: Request) { + try { + const supabaseAdmin = createClient( + process.env.NEXT_PUBLIC_SUPABASE_URL!, + process.env.SUPABASE_SERVICE_ROLE_KEY! 
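Both processing routes above fall back to `generateLocalEmbedding` (imported from `@/lib/generate-local-embedding`, not included in this hunk) when `embeddingsProvider` is `"local"`. A plausible sketch of such a helper built on transformers.js — the library and model name are assumptions, not taken from this commit:

```ts
// Hypothetical local-embedding helper — assumes transformers.js and an all-MiniLM-style model.
import { pipeline } from "@xenova/transformers"

export async function generateLocalEmbedding(content: string): Promise<number[]> {
  const extractor = await pipeline(
    "feature-extraction",
    "Xenova/all-MiniLM-L6-v2"
  )

  // Mean-pool and normalize so the output is a single fixed-length vector.
  const output = await extractor(content, { pooling: "mean", normalize: true })

  return Array.from(output.data as Float32Array)
}
```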
+ ) + + const profile = await getServerProfile() + + const formData = await req.formData() + + const file_id = formData.get("file_id") as string + const embeddingsProvider = formData.get("embeddingsProvider") as string + + const { data: fileMetadata, error: metadataError } = await supabaseAdmin + .from("files") + .select("*") + .eq("id", file_id) + .single() + + if (metadataError) { + throw new Error( + `Failed to retrieve file metadata: ${metadataError.message}` + ) + } + + if (!fileMetadata) { + throw new Error("File not found") + } + + if (fileMetadata.user_id !== profile.user_id) { + throw new Error("Unauthorized") + } + + const { data: file, error: fileError } = await supabaseAdmin.storage + .from("files") + .download(fileMetadata.file_path) + + if (fileError) + throw new Error(`Failed to retrieve file: ${fileError.message}`) + + const fileBuffer = Buffer.from(await file.arrayBuffer()) + const blob = new Blob([fileBuffer]) + const fileExtension = fileMetadata.name.split(".").pop()?.toLowerCase() + + if (embeddingsProvider === "openai") { + try { + if (profile.use_azure_openai) { + checkApiKey(profile.azure_openai_api_key, "Azure OpenAI") + } else { + checkApiKey(profile.openai_api_key, "OpenAI") + } + } catch (error: any) { + error.message = + error.message + + ", make sure it is configured or else use local embeddings" + throw error + } + } + + let chunks: FileItemChunk[] = [] + + switch (fileExtension) { + case "csv": + chunks = await processCSV(blob) + break + case "json": + chunks = await processJSON(blob) + break + case "md": + chunks = await processMarkdown(blob) + break + case "pdf": + chunks = await processPdf(blob) + break + case "txt": + chunks = await processTxt(blob) + break + default: + return new NextResponse("Unsupported file type", { + status: 400 + }) + } + + let embeddings: any = [] + + let openai + if (profile.use_azure_openai) { + openai = new OpenAI({ + apiKey: profile.azure_openai_api_key || "", + baseURL: `${profile.azure_openai_endpoint}/openai/deployments/${profile.azure_openai_embeddings_id}`, + defaultQuery: { "api-version": "2023-12-01-preview" }, + defaultHeaders: { "api-key": profile.azure_openai_api_key } + }) + } else { + openai = new OpenAI({ + apiKey: profile.openai_api_key || "", + organization: profile.openai_organization_id + }) + } + + if (embeddingsProvider === "openai") { + const response = await openai.embeddings.create({ + model: "text-embedding-3-small", + input: chunks.map(chunk => chunk.content) + }) + + embeddings = response.data.map((item: any) => { + return item.embedding + }) + } else if (embeddingsProvider === "local") { + const embeddingPromises = chunks.map(async chunk => { + try { + return await generateLocalEmbedding(chunk.content) + } catch (error) { + console.error(`Error generating embedding for chunk: ${chunk}`, error) + + return null + } + }) + + embeddings = await Promise.all(embeddingPromises) + } + + const file_items = chunks.map((chunk, index) => ({ + file_id, + user_id: profile.user_id, + content: chunk.content, + tokens: chunk.tokens, + openai_embedding: + embeddingsProvider === "openai" + ? ((embeddings[index] || null) as any) + : null, + local_embedding: + embeddingsProvider === "local" + ? 
((embeddings[index] || null) as any) + : null + })) + + await supabaseAdmin.from("file_items").upsert(file_items) + + const totalTokens = file_items.reduce((acc, item) => acc + item.tokens, 0) + + await supabaseAdmin + .from("files") + .update({ tokens: totalTokens }) + .eq("id", file_id) + + return new NextResponse("Embed Successful", { + status: 200 + }) + } catch (error: any) { + console.log(`Error in retrieval/process: ${error.stack}`) + const errorMessage = error?.message || "An unexpected error occurred" + const errorCode = error.status || 500 + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/retrieval/retrieve/route.ts b/chatdesk-ui/app/api/retrieval/retrieve/route.ts new file mode 100644 index 0000000..9c2755a --- /dev/null +++ b/chatdesk-ui/app/api/retrieval/retrieve/route.ts @@ -0,0 +1,102 @@ +import { generateLocalEmbedding } from "@/lib/generate-local-embedding" +import { checkApiKey, getServerProfile } from "@/lib/server/server-chat-helpers" +import { Database } from "@/supabase/types" +import { createClient } from "@supabase/supabase-js" +import OpenAI from "openai" + +export async function POST(request: Request) { + const json = await request.json() + const { userInput, fileIds, embeddingsProvider, sourceCount } = json as { + userInput: string + fileIds: string[] + embeddingsProvider: "openai" | "local" + sourceCount: number + } + + const uniqueFileIds = [...new Set(fileIds)] + + try { + const supabaseAdmin = createClient( + process.env.NEXT_PUBLIC_SUPABASE_URL!, + process.env.SUPABASE_SERVICE_ROLE_KEY! + ) + + const profile = await getServerProfile() + + if (embeddingsProvider === "openai") { + if (profile.use_azure_openai) { + checkApiKey(profile.azure_openai_api_key, "Azure OpenAI") + } else { + checkApiKey(profile.openai_api_key, "OpenAI") + } + } + + let chunks: any[] = [] + + let openai + if (profile.use_azure_openai) { + openai = new OpenAI({ + apiKey: profile.azure_openai_api_key || "", + baseURL: `${profile.azure_openai_endpoint}/openai/deployments/${profile.azure_openai_embeddings_id}`, + defaultQuery: { "api-version": "2023-12-01-preview" }, + defaultHeaders: { "api-key": profile.azure_openai_api_key } + }) + } else { + openai = new OpenAI({ + apiKey: profile.openai_api_key || "", + organization: profile.openai_organization_id + }) + } + + if (embeddingsProvider === "openai") { + const response = await openai.embeddings.create({ + model: "text-embedding-3-small", + input: userInput + }) + + const openaiEmbedding = response.data.map(item => item.embedding)[0] + + const { data: openaiFileItems, error: openaiError } = + await supabaseAdmin.rpc("match_file_items_openai", { + query_embedding: openaiEmbedding as any, + match_count: sourceCount, + file_ids: uniqueFileIds + }) + + if (openaiError) { + throw openaiError + } + + chunks = openaiFileItems + } else if (embeddingsProvider === "local") { + const localEmbedding = await generateLocalEmbedding(userInput) + + const { data: localFileItems, error: localFileItemsError } = + await supabaseAdmin.rpc("match_file_items_local", { + query_embedding: localEmbedding as any, + match_count: sourceCount, + file_ids: uniqueFileIds + }) + + if (localFileItemsError) { + throw localFileItemsError + } + + chunks = localFileItems + } + + const mostSimilarChunks = chunks?.sort( + (a, b) => b.similarity - a.similarity + ) + + return new Response(JSON.stringify({ results: mostSimilarChunks }), { + status: 200 + }) + } catch (error: any) { + const 
errorMessage = error.error?.message || "An unexpected error occurred" + const errorCode = error.status || 500 + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/username/available/route.ts b/chatdesk-ui/app/api/username/available/route.ts new file mode 100644 index 0000000..bf00ee0 --- /dev/null +++ b/chatdesk-ui/app/api/username/available/route.ts @@ -0,0 +1,37 @@ +import { Database } from "@/supabase/types" +import { createClient } from "@supabase/supabase-js" + +export const runtime = "edge" + +export async function POST(request: Request) { + const json = await request.json() + const { username } = json as { + username: string + } + + try { + const supabaseAdmin = createClient( + process.env.NEXT_PUBLIC_SUPABASE_URL!, + process.env.SUPABASE_SERVICE_ROLE_KEY! + ) + + const { data: usernames, error } = await supabaseAdmin + .from("profiles") + .select("username") + .eq("username", username) + + if (!usernames) { + throw new Error(error.message) + } + + return new Response(JSON.stringify({ isAvailable: !usernames.length }), { + status: 200 + }) + } catch (error: any) { + const errorMessage = error.error?.message || "An unexpected error occurred" + const errorCode = error.status || 500 + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/api/username/get/route.ts b/chatdesk-ui/app/api/username/get/route.ts new file mode 100644 index 0000000..d3cd158 --- /dev/null +++ b/chatdesk-ui/app/api/username/get/route.ts @@ -0,0 +1,38 @@ +import { Database } from "@/supabase/types" +import { createClient } from "@supabase/supabase-js" + +export const runtime = "edge" + +export async function POST(request: Request) { + const json = await request.json() + const { userId } = json as { + userId: string + } + + try { + const supabaseAdmin = createClient( + process.env.NEXT_PUBLIC_SUPABASE_URL!, + process.env.SUPABASE_SERVICE_ROLE_KEY! 
+ ) + + const { data, error } = await supabaseAdmin + .from("profiles") + .select("username") + .eq("user_id", userId) + .single() + + if (!data) { + throw new Error(error.message) + } + + return new Response(JSON.stringify({ username: data.username }), { + status: 200 + }) + } catch (error: any) { + const errorMessage = error.error?.message || "An unexpected error occurred" + const errorCode = error.status || 500 + return new Response(JSON.stringify({ message: errorMessage }), { + status: errorCode + }) + } +} diff --git a/chatdesk-ui/app/auth/callback/route.ts b/chatdesk-ui/app/auth/callback/route.ts new file mode 100644 index 0000000..acf1c65 --- /dev/null +++ b/chatdesk-ui/app/auth/callback/route.ts @@ -0,0 +1,21 @@ +import { createClient } from "@/lib/supabase/server" +import { cookies } from "next/headers" +import { NextResponse } from "next/server" + +export async function GET(request: Request) { + const requestUrl = new URL(request.url) + const code = requestUrl.searchParams.get("code") + const next = requestUrl.searchParams.get("next") + + if (code) { + const cookieStore = cookies() + const supabase = createClient(cookieStore) + await supabase.auth.exchangeCodeForSession(code) + } + + if (next) { + return NextResponse.redirect(requestUrl.origin + next) + } else { + return NextResponse.redirect(requestUrl.origin) + } +} diff --git a/chatdesk-ui/components.json b/chatdesk-ui/components.json new file mode 100644 index 0000000..433a5ad --- /dev/null +++ b/chatdesk-ui/components.json @@ -0,0 +1,16 @@ +{ + "$schema": "https://ui.shadcn.com/schema.json", + "style": "default", + "rsc": true, + "tsx": true, + "tailwind": { + "config": "tailwind.config.js", + "css": "app/globals.css", + "baseColor": "gray", + "cssVariables": true + }, + "aliases": { + "components": "@/components", + "utils": "@/lib/utils" + } +} diff --git a/chatdesk-ui/components/chat/assistant-picker.tsx b/chatdesk-ui/components/chat/assistant-picker.tsx new file mode 100644 index 0000000..04e2bdd --- /dev/null +++ b/chatdesk-ui/components/chat/assistant-picker.tsx @@ -0,0 +1,128 @@ +import { ChatbotUIContext } from "@/context/context" +import { Tables } from "@/supabase/types" +import { IconRobotFace } from "@tabler/icons-react" +import Image from "next/image" +import { FC, useContext, useEffect, useRef } from "react" +import { usePromptAndCommand } from "./chat-hooks/use-prompt-and-command" + +interface AssistantPickerProps {} + +export const AssistantPicker: FC = ({}) => { + const { + assistants, + assistantImages, + focusAssistant, + atCommand, + isAssistantPickerOpen, + setIsAssistantPickerOpen + } = useContext(ChatbotUIContext) + + const { handleSelectAssistant } = usePromptAndCommand() + + const itemsRef = useRef<(HTMLDivElement | null)[]>([]) + + useEffect(() => { + if (focusAssistant && itemsRef.current[0]) { + itemsRef.current[0].focus() + } + }, [focusAssistant]) + + const filteredAssistants = assistants.filter(assistant => + assistant.name.toLowerCase().includes(atCommand.toLowerCase()) + ) + + const handleOpenChange = (isOpen: boolean) => { + setIsAssistantPickerOpen(isOpen) + } + + const callSelectAssistant = (assistant: Tables<"assistants">) => { + handleSelectAssistant(assistant) + handleOpenChange(false) + } + + const getKeyDownHandler = + (index: number) => (e: React.KeyboardEvent) => { + if (e.key === "Backspace") { + e.preventDefault() + handleOpenChange(false) + } else if (e.key === "Enter") { + e.preventDefault() + callSelectAssistant(filteredAssistants[index]) + } else if ( + (e.key === "Tab" || 
e.key === "ArrowDown") && + !e.shiftKey && + index === filteredAssistants.length - 1 + ) { + e.preventDefault() + itemsRef.current[0]?.focus() + } else if (e.key === "ArrowUp" && !e.shiftKey && index === 0) { + // go to last element if arrow up is pressed on first element + e.preventDefault() + itemsRef.current[itemsRef.current.length - 1]?.focus() + } else if (e.key === "ArrowUp") { + e.preventDefault() + const prevIndex = + index - 1 >= 0 ? index - 1 : itemsRef.current.length - 1 + itemsRef.current[prevIndex]?.focus() + } else if (e.key === "ArrowDown") { + e.preventDefault() + const nextIndex = index + 1 < itemsRef.current.length ? index + 1 : 0 + itemsRef.current[nextIndex]?.focus() + } + } + + return ( + <> + {isAssistantPickerOpen && ( +
+ {filteredAssistants.length === 0 ? ( +
+ No matching assistants. +
+ ) : ( + <> + {filteredAssistants.map((item, index) => ( +
{ + itemsRef.current[index] = ref + }} + tabIndex={0} + className="hover:bg-accent focus:bg-accent flex cursor-pointer items-center rounded p-2 focus:outline-none" + onClick={() => + callSelectAssistant(item as Tables<"assistants">) + } + onKeyDown={getKeyDownHandler(index)} + > + {item.image_path ? ( + image.path === item.image_path + )?.url || "" + } + alt={item.name} + width={32} + height={32} + className="rounded" + /> + ) : ( + + )} + +
+
{item.name}
+ +
+ {item.description || "No description."} +
+
+
+ ))} + + )} +
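          {/* Editor's note, not part of the original patch: getKeyDownHandler above
              wraps keyboard focus, so Tab/ArrowDown on the last assistant jumps back
              to the first and ArrowUp on the first jumps to the last; Enter selects
              the focused assistant and Backspace closes the picker. */}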
+ )} + + ) +} diff --git a/chatdesk-ui/components/chat/chat-command-input.tsx b/chatdesk-ui/components/chat/chat-command-input.tsx new file mode 100644 index 0000000..49afb0b --- /dev/null +++ b/chatdesk-ui/components/chat/chat-command-input.tsx @@ -0,0 +1,48 @@ +import { ChatbotUIContext } from "@/context/context" +import { FC, useContext } from "react" +import { AssistantPicker } from "./assistant-picker" +import { usePromptAndCommand } from "./chat-hooks/use-prompt-and-command" +import { FilePicker } from "./file-picker" +import { PromptPicker } from "./prompt-picker" +import { ToolPicker } from "./tool-picker" + +interface ChatCommandInputProps {} + +export const ChatCommandInput: FC = ({}) => { + const { + newMessageFiles, + chatFiles, + slashCommand, + isFilePickerOpen, + setIsFilePickerOpen, + hashtagCommand, + focusPrompt, + focusFile + } = useContext(ChatbotUIContext) + + const { handleSelectUserFile, handleSelectUserCollection } = + usePromptAndCommand() + + return ( + <> + + + file.id + )} + selectedCollectionIds={[]} + onSelectFile={handleSelectUserFile} + onSelectCollection={handleSelectUserCollection} + isFocused={focusFile} + /> + + + + + + ) +} diff --git a/chatdesk-ui/components/chat/chat-files-display.tsx b/chatdesk-ui/components/chat/chat-files-display.tsx new file mode 100644 index 0000000..a067505 --- /dev/null +++ b/chatdesk-ui/components/chat/chat-files-display.tsx @@ -0,0 +1,283 @@ +import { ChatbotUIContext } from "@/context/context" +import { getFileFromStorage } from "@/db/storage/files" +import useHotkey from "@/lib/hooks/use-hotkey" +import { cn } from "@/lib/utils" +import { ChatFile, MessageImage } from "@/types" +import { + IconCircleFilled, + IconFileFilled, + IconFileTypeCsv, + IconFileTypeDocx, + IconFileTypePdf, + IconFileTypeTxt, + IconJson, + IconLoader2, + IconMarkdown, + IconX +} from "@tabler/icons-react" +import Image from "next/image" +import { FC, useContext, useState } from "react" +import { Button } from "../ui/button" +import { FilePreview } from "../ui/file-preview" +import { WithTooltip } from "../ui/with-tooltip" +import { ChatRetrievalSettings } from "./chat-retrieval-settings" + +interface ChatFilesDisplayProps {} + +export const ChatFilesDisplay: FC = ({}) => { + useHotkey("f", () => setShowFilesDisplay(prev => !prev)) + useHotkey("e", () => setUseRetrieval(prev => !prev)) + + const { + files, + newMessageImages, + setNewMessageImages, + newMessageFiles, + setNewMessageFiles, + setShowFilesDisplay, + showFilesDisplay, + chatFiles, + chatImages, + setChatImages, + setChatFiles, + setUseRetrieval + } = useContext(ChatbotUIContext) + + const [selectedFile, setSelectedFile] = useState(null) + const [selectedImage, setSelectedImage] = useState(null) + const [showPreview, setShowPreview] = useState(false) + + const messageImages = [ + ...newMessageImages.filter( + image => + !chatImages.some(chatImage => chatImage.messageId === image.messageId) + ) + ] + + const combinedChatFiles = [ + ...newMessageFiles.filter( + file => !chatFiles.some(chatFile => chatFile.id === file.id) + ), + ...chatFiles + ] + + const combinedMessageFiles = [...messageImages, ...combinedChatFiles] + + const getLinkAndView = async (file: ChatFile) => { + const fileRecord = files.find(f => f.id === file.id) + + if (!fileRecord) return + + const link = await getFileFromStorage(fileRecord.file_path) + window.open(link, "_blank") + } + + return showFilesDisplay && combinedMessageFiles.length > 0 ? 
( + <> + {showPreview && selectedImage && ( + { + setShowPreview(isOpen) + setSelectedImage(null) + }} + /> + )} + + {showPreview && selectedFile && ( + { + setShowPreview(isOpen) + setSelectedFile(null) + }} + /> + )} + +
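      {/* Editor's note, not part of the original patch: the "f" hotkey registered at
          the top of this component toggles the files display and "e" toggles
          retrieval for the next message (see the useHotkey calls above). */}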
+
+ +
+ +
+
+ {messageImages.map((image, index) => ( +
+ File image { + setSelectedImage(image) + setShowPreview(true) + }} + /> + + { + e.stopPropagation() + setNewMessageImages( + newMessageImages.filter( + f => f.messageId !== image.messageId + ) + ) + setChatImages( + chatImages.filter(f => f.messageId !== image.messageId) + ) + }} + /> +
+ ))} + + {combinedChatFiles.map((file, index) => + file.id === "loading" ? ( +
+
+ +
+ +
+
{file.name}
+
{file.type}
+
+
+ ) : ( +
getLinkAndView(file)} + > +
+ {(() => { + let fileExtension = file.type.includes("/") + ? file.type.split("/")[1] + : file.type + + switch (fileExtension) { + case "pdf": + return + case "markdown": + return + case "txt": + return + case "json": + return + case "csv": + return + case "docx": + return + default: + return + } + })()} +
+ +
+
{file.name}
+
+ + { + e.stopPropagation() + setNewMessageFiles( + newMessageFiles.filter(f => f.id !== file.id) + ) + setChatFiles(chatFiles.filter(f => f.id !== file.id)) + }} + /> +
+ ) + )} +
+
+
+ + ) : ( + combinedMessageFiles.length > 0 && ( +
+ +
+ ) + ) +} + +const RetrievalToggle = ({}) => { + const { useRetrieval, setUseRetrieval } = useContext(ChatbotUIContext) + + return ( +
+ + {useRetrieval + ? "File retrieval is enabled on the selected files for this message. Click the indicator to disable." + : "Click the indicator to enable file retrieval for this message."} +
+ } + trigger={ + { + e.stopPropagation() + setUseRetrieval(prev => !prev) + }} + /> + } + /> +
+ ) +} diff --git a/chatdesk-ui/components/chat/chat-help.tsx b/chatdesk-ui/components/chat/chat-help.tsx new file mode 100644 index 0000000..c0ec844 --- /dev/null +++ b/chatdesk-ui/components/chat/chat-help.tsx @@ -0,0 +1,213 @@ +import useHotkey from "@/lib/hooks/use-hotkey" +import { + IconBrandGithub, + IconBrandX, + IconHelpCircle, + IconQuestionMark +} from "@tabler/icons-react" +import Link from "next/link" +import { FC, useState } from "react" +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuLabel, + DropdownMenuSeparator, + DropdownMenuTrigger +} from "../ui/dropdown-menu" +import { Announcements } from "../utility/announcements" + +import { useTranslation } from 'react-i18next' + +interface ChatHelpProps {} + +export const ChatHelp: FC = ({}) => { + + const { t } = useTranslation() + + useHotkey("/", () => setIsOpen(prevState => !prevState)) + + const [isOpen, setIsOpen] = useState(false) + + return ( + + + + + + + +
+ + + + + + + +
+ +
+ + + + + +
+
+ + + + +
{t("help.showHelp")}
+
+
+ ⌘ +
+
+ Shift +
+
+ / +
+
+
+ + +
{t("help.showWorkspaces")}
+
+
+ ⌘ +
+
+ Shift +
+
+ ; +
+
+
+ + +
{t("help.newChat")}
+
+
+ ⌘ +
+
+ Shift +
+
+ O +
+
+
+ + +
{t("help.focusChat")}
+
+
+ ⌘ +
+
+ Shift +
+
+ L +
+
+
+ + +
{t("help.toggleFiles")}
+
+
+ ⌘ +
+
+ Shift +
+
+ F +
+
+
+ + +
{t("help.toggleRetrieval")}
+
+
+ ⌘ +
+
+ Shift +
+
+ E +
+
+
+ + +
{t("help.openSettings")}
+
+
+ ⌘ +
+
+ Shift +
+
+ I +
+
+
+ + +
{t("help.openQuickSettings")}
+
+
+ ⌘ +
+
+ Shift +
+
+ P +
+
+
+ + +
{t("help.toggleSidebar")}
+
+
+ ⌘ +
+
+ Shift +
+
+ S +
+
+
+
+
+ ) +} diff --git a/chatdesk-ui/components/chat/chat-helpers/index.ts b/chatdesk-ui/components/chat/chat-helpers/index.ts new file mode 100644 index 0000000..17a2089 --- /dev/null +++ b/chatdesk-ui/components/chat/chat-helpers/index.ts @@ -0,0 +1,511 @@ +// Only used in use-chat-handler.tsx to keep it clean + +import { createChatFiles } from "@/db/chat-files" +import { createChat } from "@/db/chats" +import { createMessageFileItems } from "@/db/message-file-items" +import { createMessages, updateMessage } from "@/db/messages" +import { uploadMessageImage } from "@/db/storage/message-images" +import { + buildFinalMessages, + adaptMessagesForGoogleGemini +} from "@/lib/build-prompt" +import { consumeReadableStream } from "@/lib/consume-stream" +import { Tables, TablesInsert } from "@/supabase/types" +import { + ChatFile, + ChatMessage, + ChatPayload, + ChatSettings, + LLM, + MessageImage +} from "@/types" +import React from "react" +import { toast } from "sonner" +import { v4 as uuidv4 } from "uuid" + +export const validateChatSettings = ( + chatSettings: ChatSettings | null, + modelData: LLM | undefined, + profile: Tables<"profiles"> | null, + selectedWorkspace: Tables<"workspaces"> | null, + messageContent: string +) => { + if (!chatSettings) { + throw new Error("Chat settings not found") + } + + if (!modelData) { + throw new Error("Model not found") + } + + if (!profile) { + throw new Error("Profile not found") + } + + if (!selectedWorkspace) { + throw new Error("Workspace not found") + } + + if (!messageContent) { + throw new Error("Message content not found") + } +} + +export const handleRetrieval = async ( + userInput: string, + newMessageFiles: ChatFile[], + chatFiles: ChatFile[], + embeddingsProvider: "openai" | "local", + sourceCount: number +) => { + const response = await fetch("/api/retrieval/retrieve", { + method: "POST", + body: JSON.stringify({ + userInput, + fileIds: [...newMessageFiles, ...chatFiles].map(file => file.id), + embeddingsProvider, + sourceCount + }) + }) + + if (!response.ok) { + console.error("Error retrieving:", response) + } + + const { results } = (await response.json()) as { + results: Tables<"file_items">[] + } + + return results +} + +export const createTempMessages = ( + messageContent: string, + chatMessages: ChatMessage[], + chatSettings: ChatSettings, + b64Images: string[], + isRegeneration: boolean, + setChatMessages: React.Dispatch>, + selectedAssistant: Tables<"assistants"> | null +) => { + let tempUserChatMessage: ChatMessage = { + message: { + chat_id: "", + assistant_id: null, + content: messageContent, + created_at: "", + id: uuidv4(), + image_paths: b64Images, + model: chatSettings.model, + role: "user", + sequence_number: chatMessages.length, + updated_at: "", + user_id: "" + }, + fileItems: [] + } + + let tempAssistantChatMessage: ChatMessage = { + message: { + chat_id: "", + assistant_id: selectedAssistant?.id || null, + content: "", + created_at: "", + id: uuidv4(), + image_paths: [], + model: chatSettings.model, + role: "assistant", + sequence_number: chatMessages.length + 1, + updated_at: "", + user_id: "" + }, + fileItems: [] + } + + let newMessages = [] + + if (isRegeneration) { + const lastMessageIndex = chatMessages.length - 1 + chatMessages[lastMessageIndex].message.content = "" + newMessages = [...chatMessages] + } else { + newMessages = [ + ...chatMessages, + tempUserChatMessage, + tempAssistantChatMessage + ] + } + + setChatMessages(newMessages) + + return { + tempUserChatMessage, + tempAssistantChatMessage + } +} + +export const 
handleLocalChat = async ( + payload: ChatPayload, + profile: Tables<"profiles">, + chatSettings: ChatSettings, + tempAssistantMessage: ChatMessage, + isRegeneration: boolean, + newAbortController: AbortController, + setIsGenerating: React.Dispatch>, + setFirstTokenReceived: React.Dispatch>, + setChatMessages: React.Dispatch>, + setToolInUse: React.Dispatch> +) => { + const formattedMessages = await buildFinalMessages(payload, profile, []) + + // Ollama API: https://github.com/jmorganca/ollama/blob/main/docs/api.md + const response = await fetchChatResponse( + process.env.NEXT_PUBLIC_OLLAMA_URL + "/api/chat", + { + model: chatSettings.model, + messages: formattedMessages, + options: { + temperature: payload.chatSettings.temperature + } + }, + false, + newAbortController, + setIsGenerating, + setChatMessages + ) + + return await processResponse( + response, + isRegeneration + ? payload.chatMessages[payload.chatMessages.length - 1] + : tempAssistantMessage, + false, + newAbortController, + setFirstTokenReceived, + setChatMessages, + setToolInUse + ) +} + +export const handleHostedChat = async ( + payload: ChatPayload, + profile: Tables<"profiles">, + modelData: LLM, + tempAssistantChatMessage: ChatMessage, + isRegeneration: boolean, + newAbortController: AbortController, + newMessageImages: MessageImage[], + chatImages: MessageImage[], + setIsGenerating: React.Dispatch>, + setFirstTokenReceived: React.Dispatch>, + setChatMessages: React.Dispatch>, + setToolInUse: React.Dispatch> +) => { + const provider = + modelData.provider === "openai" && profile.use_azure_openai + ? "azure" + : modelData.provider + + let draftMessages = await buildFinalMessages(payload, profile, chatImages) + + let formattedMessages : any[] = [] + if (provider === "google") { + formattedMessages = await adaptMessagesForGoogleGemini(payload, draftMessages) + } else { + formattedMessages = draftMessages + } + + const apiEndpoint = + provider === "custom" ? "/api/chat/custom" : `/api/chat/${provider}` + + const requestBody = { + chatSettings: payload.chatSettings, + messages: formattedMessages, + customModelId: provider === "custom" ? modelData.hostedId : "" + } + + const response = await fetchChatResponse( + apiEndpoint, + requestBody, + true, + newAbortController, + setIsGenerating, + setChatMessages + ) + + return await processResponse( + response, + isRegeneration + ? payload.chatMessages[payload.chatMessages.length - 1] + : tempAssistantChatMessage, + true, + newAbortController, + setFirstTokenReceived, + setChatMessages, + setToolInUse + ) +} + +export const fetchChatResponse = async ( + url: string, + body: object, + isHosted: boolean, + controller: AbortController, + setIsGenerating: React.Dispatch>, + setChatMessages: React.Dispatch> +) => { + const response = await fetch(url, { + method: "POST", + body: JSON.stringify(body), + signal: controller.signal + }) + + if (!response.ok) { + if (response.status === 404 && !isHosted) { + toast.error( + "Model not found. Make sure you have it downloaded via Ollama." 
+ ) + } + + const errorData = await response.json() + + toast.error(errorData.message) + + setIsGenerating(false) + setChatMessages(prevMessages => prevMessages.slice(0, -2)) + } + + return response +} + +export const processResponse = async ( + response: Response, + lastChatMessage: ChatMessage, + isHosted: boolean, + controller: AbortController, + setFirstTokenReceived: React.Dispatch>, + setChatMessages: React.Dispatch>, + setToolInUse: React.Dispatch> +) => { + let fullText = "" + let contentToAdd = "" + + if (response.body) { + await consumeReadableStream( + response.body, + chunk => { + setFirstTokenReceived(true) + setToolInUse("none") + + try { + contentToAdd = isHosted + ? chunk + : // Ollama's streaming endpoint returns new-line separated JSON + // objects. A chunk may have more than one of these objects, so we + // need to split the chunk by new-lines and handle each one + // separately. + chunk + .trimEnd() + .split("\n") + .reduce( + (acc, line) => acc + JSON.parse(line).message.content, + "" + ) + fullText += contentToAdd + } catch (error) { + console.error("Error parsing JSON:", error) + } + + setChatMessages(prev => + prev.map(chatMessage => { + if (chatMessage.message.id === lastChatMessage.message.id) { + const updatedChatMessage: ChatMessage = { + message: { + ...chatMessage.message, + content: fullText + }, + fileItems: chatMessage.fileItems + } + + return updatedChatMessage + } + + return chatMessage + }) + ) + }, + controller.signal + ) + + return fullText + } else { + throw new Error("Response body is null") + } +} + +export const handleCreateChat = async ( + chatSettings: ChatSettings, + profile: Tables<"profiles">, + selectedWorkspace: Tables<"workspaces">, + messageContent: string, + selectedAssistant: Tables<"assistants">, + newMessageFiles: ChatFile[], + setSelectedChat: React.Dispatch | null>>, + setChats: React.Dispatch[]>>, + setChatFiles: React.Dispatch> +) => { + const createdChat = await createChat({ + user_id: profile.user_id, + workspace_id: selectedWorkspace.id, + assistant_id: selectedAssistant?.id || null, + context_length: chatSettings.contextLength, + include_profile_context: chatSettings.includeProfileContext, + include_workspace_instructions: chatSettings.includeWorkspaceInstructions, + model: chatSettings.model, + name: messageContent.substring(0, 100), + prompt: chatSettings.prompt, + temperature: chatSettings.temperature, + embeddings_provider: chatSettings.embeddingsProvider + }) + + setSelectedChat(createdChat) + setChats(chats => [createdChat, ...chats]) + + await createChatFiles( + newMessageFiles.map(file => ({ + user_id: profile.user_id, + chat_id: createdChat.id, + file_id: file.id + })) + ) + + setChatFiles(prev => [...prev, ...newMessageFiles]) + + return createdChat +} + +export const handleCreateMessages = async ( + chatMessages: ChatMessage[], + currentChat: Tables<"chats">, + profile: Tables<"profiles">, + modelData: LLM, + messageContent: string, + generatedText: string, + newMessageImages: MessageImage[], + isRegeneration: boolean, + retrievedFileItems: Tables<"file_items">[], + setChatMessages: React.Dispatch>, + setChatFileItems: React.Dispatch< + React.SetStateAction[]> + >, + setChatImages: React.Dispatch>, + selectedAssistant: Tables<"assistants"> | null +) => { + const finalUserMessage: TablesInsert<"messages"> = { + chat_id: currentChat.id, + assistant_id: null, + user_id: profile.user_id, + content: messageContent, + model: modelData.modelId, + role: "user", + sequence_number: chatMessages.length, + image_paths: [] + } + 
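  // Editor's note, not part of the original patch: sequence_number appends the new
  // user message at the end of the existing history and the assistant reply below
  // takes the next slot (+1), mirroring the temp messages built in createTempMessages.
  // For example, with 4 prior messages the user message gets 4 and the reply gets 5.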
+ const finalAssistantMessage: TablesInsert<"messages"> = { + chat_id: currentChat.id, + assistant_id: selectedAssistant?.id || null, + user_id: profile.user_id, + content: generatedText, + model: modelData.modelId, + role: "assistant", + sequence_number: chatMessages.length + 1, + image_paths: [] + } + + let finalChatMessages: ChatMessage[] = [] + + if (isRegeneration) { + const lastStartingMessage = chatMessages[chatMessages.length - 1].message + + const updatedMessage = await updateMessage(lastStartingMessage.id, { + ...lastStartingMessage, + content: generatedText + }) + + chatMessages[chatMessages.length - 1].message = updatedMessage + + finalChatMessages = [...chatMessages] + + setChatMessages(finalChatMessages) + } else { + const createdMessages = await createMessages([ + finalUserMessage, + finalAssistantMessage + ]) + + // Upload each image (stored in newMessageImages) for the user message to message_images bucket + const uploadPromises = newMessageImages + .filter(obj => obj.file !== null) + .map(obj => { + let filePath = `${profile.user_id}/${currentChat.id}/${ + createdMessages[0].id + }/${uuidv4()}` + + return uploadMessageImage(filePath, obj.file as File).catch(error => { + console.error(`Failed to upload image at ${filePath}:`, error) + return null + }) + }) + + const paths = (await Promise.all(uploadPromises)).filter( + Boolean + ) as string[] + + setChatImages(prevImages => [ + ...prevImages, + ...newMessageImages.map((obj, index) => ({ + ...obj, + messageId: createdMessages[0].id, + path: paths[index] + })) + ]) + + const updatedMessage = await updateMessage(createdMessages[0].id, { + ...createdMessages[0], + image_paths: paths + }) + + const createdMessageFileItems = await createMessageFileItems( + retrievedFileItems.map(fileItem => { + return { + user_id: profile.user_id, + message_id: createdMessages[1].id, + file_item_id: fileItem.id + } + }) + ) + + finalChatMessages = [ + ...chatMessages, + { + message: updatedMessage, + fileItems: [] + }, + { + message: createdMessages[1], + fileItems: retrievedFileItems.map(fileItem => fileItem.id) + } + ] + + setChatFileItems(prevFileItems => { + const newFileItems = retrievedFileItems.filter( + fileItem => !prevFileItems.some(prevItem => prevItem.id === fileItem.id) + ) + + return [...prevFileItems, ...newFileItems] + }) + + setChatMessages(finalChatMessages) + } +} diff --git a/chatdesk-ui/components/chat/chat-hooks/use-chat-handler.tsx b/chatdesk-ui/components/chat/chat-hooks/use-chat-handler.tsx new file mode 100644 index 0000000..9eba82f --- /dev/null +++ b/chatdesk-ui/components/chat/chat-hooks/use-chat-handler.tsx @@ -0,0 +1,450 @@ +import { ChatbotUIContext } from "@/context/context" +import { getAssistantCollectionsByAssistantId } from "@/db/assistant-collections" +import { getAssistantFilesByAssistantId } from "@/db/assistant-files" +import { getAssistantToolsByAssistantId } from "@/db/assistant-tools" +import { updateChat } from "@/db/chats" +import { getCollectionFilesByCollectionId } from "@/db/collection-files" +import { deleteMessagesIncludingAndAfter } from "@/db/messages" +import { buildFinalMessages } from "@/lib/build-prompt" +import { Tables } from "@/supabase/types" +import { ChatMessage, ChatPayload, LLMID, ModelProvider } from "@/types" +import { useRouter } from "next/navigation" +import { useContext, useEffect, useRef } from "react" +import { LLM_LIST } from "../../../lib/models/llm/llm-list" + +import i18nConfig from "@/i18nConfig" + +import { + createTempMessages, + handleCreateChat, + 
handleCreateMessages, + handleHostedChat, + handleLocalChat, + handleRetrieval, + processResponse, + validateChatSettings +} from "../chat-helpers" + +import { usePathname } from "next/navigation" + +export const useChatHandler = () => { + const pathname = usePathname() // 获取当前路径 + const router = useRouter() + + // 提取当前路径中的 locale 部分 + // const locale = pathname.split("/")[1] || "en" + + const { + userInput, + chatFiles, + setUserInput, + setNewMessageImages, + profile, + setIsGenerating, + setChatMessages, + setFirstTokenReceived, + selectedChat, + selectedWorkspace, + setSelectedChat, + setChats, + setSelectedTools, + availableLocalModels, + availableOpenRouterModels, + abortController, + setAbortController, + chatSettings, + newMessageImages, + selectedAssistant, + chatMessages, + chatImages, + setChatImages, + setChatFiles, + setNewMessageFiles, + setShowFilesDisplay, + newMessageFiles, + chatFileItems, + setChatFileItems, + setToolInUse, + useRetrieval, + sourceCount, + setIsPromptPickerOpen, + setIsFilePickerOpen, + selectedTools, + selectedPreset, + setChatSettings, + models, + isPromptPickerOpen, + isFilePickerOpen, + isToolPickerOpen + } = useContext(ChatbotUIContext) + + const chatInputRef = useRef(null) + + useEffect(() => { + if (!isPromptPickerOpen || !isFilePickerOpen || !isToolPickerOpen) { + chatInputRef.current?.focus() + } + }, [isPromptPickerOpen, isFilePickerOpen, isToolPickerOpen]) + + const handleNewChat = async () => { + if (!selectedWorkspace) return + + setUserInput("") + setChatMessages([]) + setSelectedChat(null) + setChatFileItems([]) + + setIsGenerating(false) + setFirstTokenReceived(false) + + setChatFiles([]) + setChatImages([]) + setNewMessageFiles([]) + setNewMessageImages([]) + setShowFilesDisplay(false) + setIsPromptPickerOpen(false) + setIsFilePickerOpen(false) + + setSelectedTools([]) + setToolInUse("none") + + if (selectedAssistant) { + setChatSettings({ + model: selectedAssistant.model as LLMID, + prompt: selectedAssistant.prompt, + temperature: selectedAssistant.temperature, + contextLength: selectedAssistant.context_length, + includeProfileContext: selectedAssistant.include_profile_context, + includeWorkspaceInstructions: + selectedAssistant.include_workspace_instructions, + embeddingsProvider: selectedAssistant.embeddings_provider as + | "openai" + | "local" + }) + + let allFiles = [] + + const assistantFiles = ( + await getAssistantFilesByAssistantId(selectedAssistant.id) + ).files + allFiles = [...assistantFiles] + const assistantCollections = ( + await getAssistantCollectionsByAssistantId(selectedAssistant.id) + ).collections + for (const collection of assistantCollections) { + const collectionFiles = ( + await getCollectionFilesByCollectionId(collection.id) + ).files + allFiles = [...allFiles, ...collectionFiles] + } + const assistantTools = ( + await getAssistantToolsByAssistantId(selectedAssistant.id) + ).tools + + setSelectedTools(assistantTools) + setChatFiles( + allFiles.map(file => ({ + id: file.id, + name: file.name, + type: file.type, + file: null + })) + ) + + if (allFiles.length > 0) setShowFilesDisplay(true) + } else if (selectedPreset) { + setChatSettings({ + model: selectedPreset.model as LLMID, + prompt: selectedPreset.prompt, + temperature: selectedPreset.temperature, + contextLength: selectedPreset.context_length, + includeProfileContext: selectedPreset.include_profile_context, + includeWorkspaceInstructions: + selectedPreset.include_workspace_instructions, + embeddingsProvider: selectedPreset.embeddings_provider as + | "openai" 
+ | "local" + }) + } else if (selectedWorkspace) { + // setChatSettings({ + // model: (selectedWorkspace.default_model || + // "gpt-4-1106-preview") as LLMID, + // prompt: + // selectedWorkspace.default_prompt || + // "You are a friendly, helpful AI assistant.", + // temperature: selectedWorkspace.default_temperature || 0.5, + // contextLength: selectedWorkspace.default_context_length || 4096, + // includeProfileContext: + // selectedWorkspace.include_profile_context || true, + // includeWorkspaceInstructions: + // selectedWorkspace.include_workspace_instructions || true, + // embeddingsProvider: + // (selectedWorkspace.embeddings_provider as "openai" | "local") || + // "openai" + // }) + } + + + const pathSegments = pathname.split("/").filter(Boolean) + const locales = i18nConfig.locales + const defaultLocale = i18nConfig.defaultLocale + + let locale: (typeof locales)[number] = defaultLocale + const segment = pathSegments[0] as (typeof locales)[number] + + if (locales.includes(segment)) { + locale = segment + } + + // ✅ 正确构造 localePrefix,不包含前导 / + const localePrefix = locale === defaultLocale ? "" : `/${locale}` + + console.log("[use-chat-handler.tsx]...........localePrefix", localePrefix) + + return router.push(`${localePrefix}/${selectedWorkspace.id}/chat`) + + // return router.push(`/${locale}/${selectedWorkspace.id}/chat`) + } + + const handleFocusChatInput = () => { + chatInputRef.current?.focus() + } + + const handleStopMessage = () => { + if (abortController) { + abortController.abort() + } + } + + const handleSendMessage = async ( + messageContent: string, + chatMessages: ChatMessage[], + isRegeneration: boolean + ) => { + const startingInput = messageContent + + try { + setUserInput("") + setIsGenerating(true) + setIsPromptPickerOpen(false) + setIsFilePickerOpen(false) + setNewMessageImages([]) + + const newAbortController = new AbortController() + setAbortController(newAbortController) + + const modelData = [ + ...models.map(model => ({ + modelId: model.model_id as LLMID, + modelName: model.name, + provider: "custom" as ModelProvider, + hostedId: model.id, + platformLink: "", + imageInput: false + })), + ...LLM_LIST, + ...availableLocalModels, + ...availableOpenRouterModels + ].find(llm => llm.modelId === chatSettings?.model) + + validateChatSettings( + chatSettings, + modelData, + profile, + selectedWorkspace, + messageContent + ) + + let currentChat = selectedChat ? { ...selectedChat } : null + + const b64Images = newMessageImages.map(image => image.base64) + + let retrievedFileItems: Tables<"file_items">[] = [] + + if ( + (newMessageFiles.length > 0 || chatFiles.length > 0) && + useRetrieval + ) { + setToolInUse("retrieval") + + retrievedFileItems = await handleRetrieval( + userInput, + newMessageFiles, + chatFiles, + chatSettings!.embeddingsProvider, + sourceCount + ) + } + + const { tempUserChatMessage, tempAssistantChatMessage } = + createTempMessages( + messageContent, + chatMessages, + chatSettings!, + b64Images, + isRegeneration, + setChatMessages, + selectedAssistant + ) + + let payload: ChatPayload = { + chatSettings: chatSettings!, + workspaceInstructions: selectedWorkspace!.instructions || "", + chatMessages: isRegeneration + ? [...chatMessages] + : [...chatMessages, tempUserChatMessage], + assistant: selectedChat?.assistant_id ? 
selectedAssistant : null, + messageFileItems: retrievedFileItems, + chatFileItems: chatFileItems + } + + let generatedText = "" + + if (selectedTools.length > 0) { + setToolInUse("Tools") + + const formattedMessages = await buildFinalMessages( + payload, + profile!, + chatImages + ) + + const response = await fetch("/api/chat/tools", { + method: "POST", + headers: { + "Content-Type": "application/json" + }, + body: JSON.stringify({ + chatSettings: payload.chatSettings, + messages: formattedMessages, + selectedTools + }) + }) + + setToolInUse("none") + + generatedText = await processResponse( + response, + isRegeneration + ? payload.chatMessages[payload.chatMessages.length - 1] + : tempAssistantChatMessage, + true, + newAbortController, + setFirstTokenReceived, + setChatMessages, + setToolInUse + ) + } else { + if (modelData!.provider === "ollama") { + generatedText = await handleLocalChat( + payload, + profile!, + chatSettings!, + tempAssistantChatMessage, + isRegeneration, + newAbortController, + setIsGenerating, + setFirstTokenReceived, + setChatMessages, + setToolInUse + ) + } else { + generatedText = await handleHostedChat( + payload, + profile!, + modelData!, + tempAssistantChatMessage, + isRegeneration, + newAbortController, + newMessageImages, + chatImages, + setIsGenerating, + setFirstTokenReceived, + setChatMessages, + setToolInUse + ) + } + } + + if (!currentChat) { + currentChat = await handleCreateChat( + chatSettings!, + profile!, + selectedWorkspace!, + messageContent, + selectedAssistant!, + newMessageFiles, + setSelectedChat, + setChats, + setChatFiles + ) + } else { + const updatedChat = await updateChat(currentChat.id, { + updated_at: new Date().toISOString() + }) + + setChats(prevChats => { + const updatedChats = prevChats.map(prevChat => + prevChat.id === updatedChat.id ? updatedChat : prevChat + ) + + return updatedChats + }) + } + + await handleCreateMessages( + chatMessages, + currentChat, + profile!, + modelData!, + messageContent, + generatedText, + newMessageImages, + isRegeneration, + retrievedFileItems, + setChatMessages, + setChatFileItems, + setChatImages, + selectedAssistant + ) + + setIsGenerating(false) + setFirstTokenReceived(false) + } catch (error) { + setIsGenerating(false) + setFirstTokenReceived(false) + setUserInput(startingInput) + } + } + + const handleSendEdit = async ( + editedContent: string, + sequenceNumber: number + ) => { + if (!selectedChat) return + + await deleteMessagesIncludingAndAfter( + selectedChat.user_id, + selectedChat.id, + sequenceNumber + ) + + const filteredMessages = chatMessages.filter( + chatMessage => chatMessage.message.sequence_number < sequenceNumber + ) + + setChatMessages(filteredMessages) + + handleSendMessage(editedContent, filteredMessages, false) + } + + return { + chatInputRef, + prompt, + handleNewChat, + handleSendMessage, + handleFocusChatInput, + handleStopMessage, + handleSendEdit + } +} diff --git a/chatdesk-ui/components/chat/chat-hooks/use-chat-history.tsx b/chatdesk-ui/components/chat/chat-hooks/use-chat-history.tsx new file mode 100644 index 0000000..cbb376e --- /dev/null +++ b/chatdesk-ui/components/chat/chat-hooks/use-chat-history.tsx @@ -0,0 +1,77 @@ +import { ChatbotUIContext } from "@/context/context" +import { useContext, useEffect, useState } from "react" + +/** + * Custom hook for handling chat history in the chat component. + * It provides functions to set the new message content to the previous or next user message in the chat history. 
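 * (Editor's illustrative example, not part of the original patch.)
 * @example
 * const {
 *   setNewMessageContentToPreviousUserMessage,
 *   setNewMessageContentToNextUserMessage
 * } = useChatHistoryHandler()
 * // for instance, recall the previous user message into the chat input on ArrowUp
 * setNewMessageContentToPreviousUserMessage()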
+ * + * @returns An object containing the following functions: + * - setNewMessageContentToPreviousUserMessage: Sets the new message content to the previous user message. + * - setNewMessageContentToNextUserMessage: Sets the new message content to the next user message in the chat history. + */ +export const useChatHistoryHandler = () => { + const { setUserInput, chatMessages, isGenerating } = + useContext(ChatbotUIContext) + const userRoleString = "user" + + const [messageHistoryIndex, setMessageHistoryIndex] = useState( + chatMessages.length + ) + + useEffect(() => { + // If messages get deleted the history index pointed could be out of bounds + if (!isGenerating && messageHistoryIndex > chatMessages.length) + setMessageHistoryIndex(chatMessages.length) + }, [chatMessages, isGenerating, messageHistoryIndex]) + + /** + * Sets the new message content to the previous user message. + */ + const setNewMessageContentToPreviousUserMessage = () => { + let tempIndex = messageHistoryIndex + while ( + tempIndex > 0 && + chatMessages[tempIndex - 1].message.role !== userRoleString + ) { + tempIndex-- + } + + const previousUserMessage = + chatMessages.length > 0 && tempIndex > 0 + ? chatMessages[tempIndex - 1] + : null + if (previousUserMessage) { + setUserInput(previousUserMessage.message.content) + setMessageHistoryIndex(tempIndex - 1) + } + } + + /** + * Sets the new message content to the next user message in the chat history. + * If there is a next user message, it updates the user input and message history index accordingly. + * If there is no next user message, it resets the user input and sets the message history index to the end of the chat history. + */ + const setNewMessageContentToNextUserMessage = () => { + let tempIndex = messageHistoryIndex + while ( + tempIndex < chatMessages.length - 1 && + chatMessages[tempIndex + 1].message.role !== userRoleString + ) { + tempIndex++ + } + + const nextUserMessage = + chatMessages.length > 0 && tempIndex < chatMessages.length - 1 + ? chatMessages[tempIndex + 1] + : null + setUserInput(nextUserMessage?.message.content || "") + setMessageHistoryIndex( + nextUserMessage ? 
tempIndex + 1 : chatMessages.length + ) + } + + return { + setNewMessageContentToPreviousUserMessage, + setNewMessageContentToNextUserMessage + } +} diff --git a/chatdesk-ui/components/chat/chat-hooks/use-prompt-and-command.tsx b/chatdesk-ui/components/chat/chat-hooks/use-prompt-and-command.tsx new file mode 100644 index 0000000..aaa1925 --- /dev/null +++ b/chatdesk-ui/components/chat/chat-hooks/use-prompt-and-command.tsx @@ -0,0 +1,190 @@ +import { ChatbotUIContext } from "@/context/context" +import { getAssistantCollectionsByAssistantId } from "@/db/assistant-collections" +import { getAssistantFilesByAssistantId } from "@/db/assistant-files" +import { getAssistantToolsByAssistantId } from "@/db/assistant-tools" +import { getCollectionFilesByCollectionId } from "@/db/collection-files" +import { Tables } from "@/supabase/types" +import { LLMID } from "@/types" +import { useContext } from "react" + +export const usePromptAndCommand = () => { + const { + chatFiles, + setNewMessageFiles, + userInput, + setUserInput, + setShowFilesDisplay, + setIsPromptPickerOpen, + setIsFilePickerOpen, + setSlashCommand, + setHashtagCommand, + setUseRetrieval, + setToolCommand, + setIsToolPickerOpen, + setSelectedTools, + setAtCommand, + setIsAssistantPickerOpen, + setSelectedAssistant, + setChatSettings, + setChatFiles + } = useContext(ChatbotUIContext) + + const handleInputChange = (value: string) => { + const atTextRegex = /@([^ ]*)$/ + const slashTextRegex = /\/([^ ]*)$/ + const hashtagTextRegex = /#([^ ]*)$/ + const toolTextRegex = /!([^ ]*)$/ + const atMatch = value.match(atTextRegex) + const slashMatch = value.match(slashTextRegex) + const hashtagMatch = value.match(hashtagTextRegex) + const toolMatch = value.match(toolTextRegex) + + if (atMatch) { + setIsAssistantPickerOpen(true) + setAtCommand(atMatch[1]) + } else if (slashMatch) { + setIsPromptPickerOpen(true) + setSlashCommand(slashMatch[1]) + } else if (hashtagMatch) { + setIsFilePickerOpen(true) + setHashtagCommand(hashtagMatch[1]) + } else if (toolMatch) { + setIsToolPickerOpen(true) + setToolCommand(toolMatch[1]) + } else { + setIsPromptPickerOpen(false) + setIsFilePickerOpen(false) + setIsToolPickerOpen(false) + setIsAssistantPickerOpen(false) + setSlashCommand("") + setHashtagCommand("") + setToolCommand("") + setAtCommand("") + } + + setUserInput(value) + } + + const handleSelectPrompt = (prompt: Tables<"prompts">) => { + setIsPromptPickerOpen(false) + setUserInput(userInput.replace(/\/[^ ]*$/, "") + prompt.content) + } + + const handleSelectUserFile = async (file: Tables<"files">) => { + setShowFilesDisplay(true) + setIsFilePickerOpen(false) + setUseRetrieval(true) + + setNewMessageFiles(prev => { + const fileAlreadySelected = + prev.some(prevFile => prevFile.id === file.id) || + chatFiles.some(chatFile => chatFile.id === file.id) + + if (!fileAlreadySelected) { + return [ + ...prev, + { + id: file.id, + name: file.name, + type: file.type, + file: null + } + ] + } + return prev + }) + + setUserInput(userInput.replace(/#[^ ]*$/, "")) + } + + const handleSelectUserCollection = async ( + collection: Tables<"collections"> + ) => { + setShowFilesDisplay(true) + setIsFilePickerOpen(false) + setUseRetrieval(true) + + const collectionFiles = await getCollectionFilesByCollectionId( + collection.id + ) + + setNewMessageFiles(prev => { + const newFiles = collectionFiles.files + .filter( + file => + !prev.some(prevFile => prevFile.id === file.id) && + !chatFiles.some(chatFile => chatFile.id === file.id) + ) + .map(file => ({ + id: file.id, + name: 
file.name, + type: file.type, + file: null + })) + + return [...prev, ...newFiles] + }) + + setUserInput(userInput.replace(/#[^ ]*$/, "")) + } + + const handleSelectTool = (tool: Tables<"tools">) => { + setIsToolPickerOpen(false) + setUserInput(userInput.replace(/![^ ]*$/, "")) + setSelectedTools(prev => [...prev, tool]) + } + + const handleSelectAssistant = async (assistant: Tables<"assistants">) => { + setIsAssistantPickerOpen(false) + setUserInput(userInput.replace(/@[^ ]*$/, "")) + setSelectedAssistant(assistant) + + setChatSettings({ + model: assistant.model as LLMID, + prompt: assistant.prompt, + temperature: assistant.temperature, + contextLength: assistant.context_length, + includeProfileContext: assistant.include_profile_context, + includeWorkspaceInstructions: assistant.include_workspace_instructions, + embeddingsProvider: assistant.embeddings_provider as "openai" | "local" + }) + + let allFiles = [] + + const assistantFiles = (await getAssistantFilesByAssistantId(assistant.id)) + .files + allFiles = [...assistantFiles] + const assistantCollections = ( + await getAssistantCollectionsByAssistantId(assistant.id) + ).collections + for (const collection of assistantCollections) { + const collectionFiles = ( + await getCollectionFilesByCollectionId(collection.id) + ).files + allFiles = [...allFiles, ...collectionFiles] + } + const assistantTools = (await getAssistantToolsByAssistantId(assistant.id)) + .tools + + setSelectedTools(assistantTools) + setChatFiles( + allFiles.map(file => ({ + id: file.id, + name: file.name, + type: file.type, + file: null + })) + ) + + if (allFiles.length > 0) setShowFilesDisplay(true) + } + + return { + handleInputChange, + handleSelectPrompt, + handleSelectUserFile, + handleSelectUserCollection, + handleSelectTool, + handleSelectAssistant + } +} diff --git a/chatdesk-ui/components/chat/chat-hooks/use-scroll.tsx b/chatdesk-ui/components/chat/chat-hooks/use-scroll.tsx new file mode 100644 index 0000000..9c6aea0 --- /dev/null +++ b/chatdesk-ui/components/chat/chat-hooks/use-scroll.tsx @@ -0,0 +1,87 @@ +import { ChatbotUIContext } from "@/context/context" +import { + type UIEventHandler, + useCallback, + useContext, + useEffect, + useRef, + useState +} from "react" + +export const useScroll = () => { + const { isGenerating, chatMessages } = useContext(ChatbotUIContext) + + const messagesStartRef = useRef(null) + const messagesEndRef = useRef(null) + const isAutoScrolling = useRef(false) + + const [isAtTop, setIsAtTop] = useState(false) + const [isAtBottom, setIsAtBottom] = useState(true) + const [userScrolled, setUserScrolled] = useState(false) + const [isOverflowing, setIsOverflowing] = useState(false) + + useEffect(() => { + setUserScrolled(false) + + if (!isGenerating && userScrolled) { + setUserScrolled(false) + } + }, [isGenerating]) + + useEffect(() => { + if (isGenerating && !userScrolled) { + scrollToBottom() + } + }, [chatMessages]) + + const handleScroll: UIEventHandler = useCallback(e => { + const target = e.target as HTMLDivElement + const bottom = + Math.round(target.scrollHeight) - Math.round(target.scrollTop) === + Math.round(target.clientHeight) + setIsAtBottom(bottom) + + const top = target.scrollTop === 0 + setIsAtTop(top) + + if (!bottom && !isAutoScrolling.current) { + setUserScrolled(true) + } else { + setUserScrolled(false) + } + + const isOverflow = target.scrollHeight > target.clientHeight + setIsOverflowing(isOverflow) + }, []) + + const scrollToTop = useCallback(() => { + if (messagesStartRef.current) { + 
messagesStartRef.current.scrollIntoView({ behavior: "instant" }) + } + }, []) + + const scrollToBottom = useCallback(() => { + isAutoScrolling.current = true + + setTimeout(() => { + if (messagesEndRef.current) { + messagesEndRef.current.scrollIntoView({ behavior: "instant" }) + } + + isAutoScrolling.current = false + }, 100) + }, []) + + return { + messagesStartRef, + messagesEndRef, + isAtTop, + isAtBottom, + userScrolled, + isOverflowing, + handleScroll, + scrollToTop, + scrollToBottom, + setIsAtBottom + } +} diff --git a/chatdesk-ui/components/chat/chat-hooks/use-select-file-handler.tsx b/chatdesk-ui/components/chat/chat-hooks/use-select-file-handler.tsx new file mode 100644 index 0000000..103ce6f --- /dev/null +++ b/chatdesk-ui/components/chat/chat-hooks/use-select-file-handler.tsx @@ -0,0 +1,204 @@ +import { ChatbotUIContext } from "@/context/context" +import { createDocXFile, createFile } from "@/db/files" +import { LLM_LIST } from "@/lib/models/llm/llm-list" +import mammoth from "mammoth" +import { useContext, useEffect, useState } from "react" +import { toast } from "sonner" + +export const ACCEPTED_FILE_TYPES = [ + "text/csv", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/json", + "text/markdown", + "application/pdf", + "text/plain" +].join(",") + +export const useSelectFileHandler = () => { + const { + selectedWorkspace, + profile, + chatSettings, + setNewMessageImages, + setNewMessageFiles, + setShowFilesDisplay, + setFiles, + setUseRetrieval + } = useContext(ChatbotUIContext) + + const [filesToAccept, setFilesToAccept] = useState(ACCEPTED_FILE_TYPES) + + useEffect(() => { + handleFilesToAccept() + }, [chatSettings?.model]) + + const handleFilesToAccept = () => { + const model = chatSettings?.model + const FULL_MODEL = LLM_LIST.find(llm => llm.modelId === model) + + if (!FULL_MODEL) return + + setFilesToAccept( + FULL_MODEL.imageInput + ? `${ACCEPTED_FILE_TYPES},image/*` + : ACCEPTED_FILE_TYPES + ) + } + + const handleSelectDeviceFile = async (file: File) => { + if (!profile || !selectedWorkspace || !chatSettings) return + + setShowFilesDisplay(true) + setUseRetrieval(true) + + if (file) { + let simplifiedFileType = file.type.split("/")[1] + + let reader = new FileReader() + + if (file.type.includes("image")) { + reader.readAsDataURL(file) + } else if (ACCEPTED_FILE_TYPES.split(",").includes(file.type)) { + if (simplifiedFileType.includes("vnd.adobe.pdf")) { + simplifiedFileType = "pdf" + } else if ( + simplifiedFileType.includes( + "vnd.openxmlformats-officedocument.wordprocessingml.document" || + "docx" + ) + ) { + simplifiedFileType = "docx" + } + + setNewMessageFiles(prev => [ + ...prev, + { + id: "loading", + name: file.name, + type: simplifiedFileType, + file: file + } + ]) + + // Handle docx files + if ( + file.type.includes( + "vnd.openxmlformats-officedocument.wordprocessingml.document" || + "docx" + ) + ) { + const arrayBuffer = await file.arrayBuffer() + const result = await mammoth.extractRawText({ + arrayBuffer + }) + + const createdFile = await createDocXFile( + result.value, + file, + { + user_id: profile.user_id, + description: "", + file_path: "", + name: file.name, + size: file.size, + tokens: 0, + type: simplifiedFileType + }, + selectedWorkspace.id, + chatSettings.embeddingsProvider + ) + + setFiles(prev => [...prev, createdFile]) + + setNewMessageFiles(prev => + prev.map(item => + item.id === "loading" + ? 
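                // Editor's note, not part of the original patch: replace the temporary
                // "loading" placeholder pushed above with the persisted record returned
                // by createDocXFile, keeping the original File object for display.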
{ + id: createdFile.id, + name: createdFile.name, + type: createdFile.type, + file: file + } + : item + ) + ) + + reader.onloadend = null + + return + } else { + // Use readAsArrayBuffer for PDFs and readAsText for other types + file.type.includes("pdf") + ? reader.readAsArrayBuffer(file) + : reader.readAsText(file) + } + } else { + throw new Error("Unsupported file type") + } + + reader.onloadend = async function () { + try { + if (file.type.includes("image")) { + // Create a temp url for the image file + const imageUrl = URL.createObjectURL(file) + + // This is a temporary image for display purposes in the chat input + setNewMessageImages(prev => [ + ...prev, + { + messageId: "temp", + path: "", + base64: reader.result, // base64 image + url: imageUrl, + file + } + ]) + } else { + const createdFile = await createFile( + file, + { + user_id: profile.user_id, + description: "", + file_path: "", + name: file.name, + size: file.size, + tokens: 0, + type: simplifiedFileType + }, + selectedWorkspace.id, + chatSettings.embeddingsProvider + ) + + setFiles(prev => [...prev, createdFile]) + + setNewMessageFiles(prev => + prev.map(item => + item.id === "loading" + ? { + id: createdFile.id, + name: createdFile.name, + type: createdFile.type, + file: file + } + : item + ) + ) + } + } catch (error: any) { + toast.error("Failed to upload. " + error?.message, { + duration: 10000 + }) + setNewMessageImages(prev => + prev.filter(img => img.messageId !== "temp") + ) + setNewMessageFiles(prev => prev.filter(file => file.id !== "loading")) + } + } + } + } + + return { + handleSelectDeviceFile, + filesToAccept + } +} diff --git a/chatdesk-ui/components/chat/chat-input.tsx b/chatdesk-ui/components/chat/chat-input.tsx new file mode 100644 index 0000000..a7aea91 --- /dev/null +++ b/chatdesk-ui/components/chat/chat-input.tsx @@ -0,0 +1,281 @@ +import { ChatbotUIContext } from "@/context/context" +import useHotkey from "@/lib/hooks/use-hotkey" +import { LLM_LIST } from "@/lib/models/llm/llm-list" +import { cn } from "@/lib/utils" +import { + IconBolt, + IconCirclePlus, + IconPlayerStopFilled, + IconSend +} from "@tabler/icons-react" +import Image from "next/image" +import { FC, useContext, useEffect, useRef, useState } from "react" +import { useTranslation } from "react-i18next" +import { toast } from "sonner" +import { Input } from "../ui/input" +import { TextareaAutosize } from "../ui/textarea-autosize" +import { ChatCommandInput } from "./chat-command-input" +import { ChatFilesDisplay } from "./chat-files-display" +import { useChatHandler } from "./chat-hooks/use-chat-handler" +import { useChatHistoryHandler } from "./chat-hooks/use-chat-history" +import { usePromptAndCommand } from "./chat-hooks/use-prompt-and-command" +import { useSelectFileHandler } from "./chat-hooks/use-select-file-handler" + +interface ChatInputProps {} + +export const ChatInput: FC = ({}) => { + const { t } = useTranslation() + + useHotkey("l", () => { + handleFocusChatInput() + }) + + const [isTyping, setIsTyping] = useState(false) + + const { + isAssistantPickerOpen, + focusAssistant, + setFocusAssistant, + userInput, + chatMessages, + isGenerating, + selectedPreset, + selectedAssistant, + focusPrompt, + setFocusPrompt, + focusFile, + focusTool, + setFocusTool, + isToolPickerOpen, + isPromptPickerOpen, + setIsPromptPickerOpen, + isFilePickerOpen, + setFocusFile, + chatSettings, + selectedTools, + setSelectedTools, + assistantImages + } = useContext(ChatbotUIContext) + + const { + chatInputRef, + handleSendMessage, + handleStopMessage, 
+ handleFocusChatInput + } = useChatHandler() + + const { handleInputChange } = usePromptAndCommand() + + const { filesToAccept, handleSelectDeviceFile } = useSelectFileHandler() + + const { + setNewMessageContentToNextUserMessage, + setNewMessageContentToPreviousUserMessage + } = useChatHistoryHandler() + + const fileInputRef = useRef(null) + + useEffect(() => { + setTimeout(() => { + handleFocusChatInput() + }, 200) // FIX: hacky + }, [selectedPreset, selectedAssistant]) + + const handleKeyDown = (event: React.KeyboardEvent) => { + if (!isTyping && event.key === "Enter" && !event.shiftKey) { + event.preventDefault() + setIsPromptPickerOpen(false) + handleSendMessage(userInput, chatMessages, false) + } + + // Consolidate conditions to avoid TypeScript error + if ( + isPromptPickerOpen || + isFilePickerOpen || + isToolPickerOpen || + isAssistantPickerOpen + ) { + if ( + event.key === "Tab" || + event.key === "ArrowUp" || + event.key === "ArrowDown" + ) { + event.preventDefault() + // Toggle focus based on picker type + if (isPromptPickerOpen) setFocusPrompt(!focusPrompt) + if (isFilePickerOpen) setFocusFile(!focusFile) + if (isToolPickerOpen) setFocusTool(!focusTool) + if (isAssistantPickerOpen) setFocusAssistant(!focusAssistant) + } + } + + if (event.key === "ArrowUp" && event.shiftKey && event.ctrlKey) { + event.preventDefault() + setNewMessageContentToPreviousUserMessage() + } + + if (event.key === "ArrowDown" && event.shiftKey && event.ctrlKey) { + event.preventDefault() + setNewMessageContentToNextUserMessage() + } + + //use shift+ctrl+up and shift+ctrl+down to navigate through chat history + if (event.key === "ArrowUp" && event.shiftKey && event.ctrlKey) { + event.preventDefault() + setNewMessageContentToPreviousUserMessage() + } + + if (event.key === "ArrowDown" && event.shiftKey && event.ctrlKey) { + event.preventDefault() + setNewMessageContentToNextUserMessage() + } + + if ( + isAssistantPickerOpen && + (event.key === "Tab" || + event.key === "ArrowUp" || + event.key === "ArrowDown") + ) { + event.preventDefault() + setFocusAssistant(!focusAssistant) + } + } + + const handlePaste = (event: React.ClipboardEvent) => { + const imagesAllowed = LLM_LIST.find( + llm => llm.modelId === chatSettings?.model + )?.imageInput + + const items = event.clipboardData.items + for (const item of items) { + if (item.type.indexOf("image") === 0) { + if (!imagesAllowed) { + toast.error( + `Images are not supported for this model. Use models like GPT-4 Vision instead.` + ) + return + } + const file = item.getAsFile() + if (!file) return + handleSelectDeviceFile(file) + } + } + } + + return ( + <> +
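// handleKeyDown above registers the Ctrl+Shift+ArrowUp/ArrowDown history
// shortcuts twice and repeats picker focus handling that the consolidated
// branch already covers; the duplicates are harmless but redundant. A compact
// sketch of the history part, reusing the setter names destructured from
// useChatHistoryHandler above (illustrative only):
const handleHistoryKeys = (
  event: React.KeyboardEvent,
  onPrevious: () => void,
  onNext: () => void
) => {
  if (!(event.ctrlKey && event.shiftKey)) return
  if (event.key === "ArrowUp") {
    event.preventDefault()
    onPrevious()
  } else if (event.key === "ArrowDown") {
    event.preventDefault()
    onNext()
  }
}

// Usage inside handleKeyDown:
//   handleHistoryKeys(
//     event,
//     setNewMessageContentToPreviousUserMessage,
//     setNewMessageContentToNextUserMessage
//   )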
+ + + {selectedTools && + selectedTools.map((tool, index) => ( +
+ setSelectedTools( + selectedTools.filter( + selectedTool => selectedTool.id !== tool.id + ) + ) + } + > +
+ + +
{tool.name}
+
+
+ ))} + + {selectedAssistant && ( +
+ {selectedAssistant.image_path && ( + img.path === selectedAssistant.image_path + )?.base64 + } + width={28} + height={28} + alt={selectedAssistant.name} + /> + )} + +
+ Talking to {selectedAssistant.name} +
+
+ )} +
+ +
+
+ +
+ + <> + fileInputRef.current?.click()} + /> + + {/* Hidden input to select files from device */} + { + if (!e.target.files) return + handleSelectDeviceFile(e.target.files[0]) + }} + accept={filesToAccept} + /> + + + setIsTyping(true)} + onCompositionEnd={() => setIsTyping(false)} + /> + +
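// The onCompositionStart/onCompositionEnd pair on the textarea above drives the
// isTyping flag that handleKeyDown checks before treating Enter as "send";
// without it, confirming a CJK IME composition with Enter would submit the
// message. The same guard as a tiny standalone hook (a sketch under that
// assumption, not code that exists in this repo):
import { useState } from "react"

export const useImeAwareEnter = () => {
  const [isComposing, setIsComposing] = useState(false)

  // True only for a plain Enter pressed outside an active IME composition.
  const isSendKey = (e: { key: string; shiftKey: boolean }) =>
    !isComposing && e.key === "Enter" && !e.shiftKey

  return {
    isSendKey,
    compositionProps: {
      onCompositionStart: () => setIsComposing(true),
      onCompositionEnd: () => setIsComposing(false)
    }
  }
}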
+ {isGenerating ? ( + + ) : ( + { + if (!userInput) return + + handleSendMessage(userInput, chatMessages, false) + }} + size={30} + /> + )} +
+
+ + ) +} diff --git a/chatdesk-ui/components/chat/chat-messages.tsx b/chatdesk-ui/components/chat/chat-messages.tsx new file mode 100644 index 0000000..af13a88 --- /dev/null +++ b/chatdesk-ui/components/chat/chat-messages.tsx @@ -0,0 +1,38 @@ +import { useChatHandler } from "@/components/chat/chat-hooks/use-chat-handler" +import { ChatbotUIContext } from "@/context/context" +import { Tables } from "@/supabase/types" +import { FC, useContext, useState } from "react" +import { Message } from "../messages/message" + +interface ChatMessagesProps {} + +export const ChatMessages: FC = ({}) => { + const { chatMessages, chatFileItems } = useContext(ChatbotUIContext) + + const { handleSendEdit } = useChatHandler() + + const [editingMessage, setEditingMessage] = useState>() + + return chatMessages + .sort((a, b) => a.message.sequence_number - b.message.sequence_number) + .map((chatMessage, index, array) => { + const messageFileItems = chatFileItems.filter( + (chatFileItem, _, self) => + chatMessage.fileItems.includes(chatFileItem.id) && + self.findIndex(item => item.id === chatFileItem.id) === _ + ) + + return ( + setEditingMessage(undefined)} + onSubmitEdit={handleSendEdit} + /> + ) + }) +} diff --git a/chatdesk-ui/components/chat/chat-retrieval-settings.tsx b/chatdesk-ui/components/chat/chat-retrieval-settings.tsx new file mode 100644 index 0000000..2a47521 --- /dev/null +++ b/chatdesk-ui/components/chat/chat-retrieval-settings.tsx @@ -0,0 +1,65 @@ +import { ChatbotUIContext } from "@/context/context" +import { IconAdjustmentsHorizontal } from "@tabler/icons-react" +import { FC, useContext, useState } from "react" +import { Button } from "../ui/button" +import { + Dialog, + DialogContent, + DialogFooter, + DialogTrigger +} from "../ui/dialog" +import { Label } from "../ui/label" +import { Slider } from "../ui/slider" +import { WithTooltip } from "../ui/with-tooltip" + +interface ChatRetrievalSettingsProps {} + +export const ChatRetrievalSettings: FC = ({}) => { + const { sourceCount, setSourceCount } = useContext(ChatbotUIContext) + + const [isOpen, setIsOpen] = useState(false) + + return ( + + + Adjust retrieval settings.} + trigger={ + + } + /> + + + +
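// ChatMessages above sorts chatMessages from ChatbotUIContext with .sort(),
// which reorders the array held in context in place rather than a local copy.
// It works because the context value is replaced on updates, but copying
// before sorting is the safer pattern. A sketch only, not a change to the
// component:
const sortBySequence = <T extends { message: { sequence_number: number } }>(
  items: T[]
): T[] =>
  [...items].sort(
    (a, b) => a.message.sequence_number - b.message.sequence_number
  )

// return sortBySequence(chatMessages).map(...)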
+ + + { + setSourceCount(values[0]) + }} + min={1} + max={10} + step={1} + /> +
+ + + + +
+
+ ) +} diff --git a/chatdesk-ui/components/chat/chat-scroll-buttons.tsx b/chatdesk-ui/components/chat/chat-scroll-buttons.tsx new file mode 100644 index 0000000..3eb6f2d --- /dev/null +++ b/chatdesk-ui/components/chat/chat-scroll-buttons.tsx @@ -0,0 +1,41 @@ +import { + IconCircleArrowDownFilled, + IconCircleArrowUpFilled +} from "@tabler/icons-react" +import { FC } from "react" + +interface ChatScrollButtonsProps { + isAtTop: boolean + isAtBottom: boolean + isOverflowing: boolean + scrollToTop: () => void + scrollToBottom: () => void +} + +export const ChatScrollButtons: FC = ({ + isAtTop, + isAtBottom, + isOverflowing, + scrollToTop, + scrollToBottom +}) => { + return ( + <> + {!isAtTop && isOverflowing && ( + + )} + + {!isAtBottom && isOverflowing && ( + + )} + + ) +} diff --git a/chatdesk-ui/components/chat/chat-secondary-buttons.tsx b/chatdesk-ui/components/chat/chat-secondary-buttons.tsx new file mode 100644 index 0000000..a52cec2 --- /dev/null +++ b/chatdesk-ui/components/chat/chat-secondary-buttons.tsx @@ -0,0 +1,82 @@ +import { useChatHandler } from "@/components/chat/chat-hooks/use-chat-handler" +import { ChatbotUIContext } from "@/context/context" +import { IconInfoCircle, IconMessagePlus } from "@tabler/icons-react" +import { FC, useContext } from "react" +import { WithTooltip } from "../ui/with-tooltip" +import { useTranslation } from 'react-i18next' + +interface ChatSecondaryButtonsProps {} + +export const ChatSecondaryButtons: FC = ({}) => { + + const { t } = useTranslation() + + const { selectedChat } = useContext(ChatbotUIContext) + + const { handleNewChat } = useChatHandler() + + return ( + <> + {selectedChat && ( + <> + +
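// ChatSecondaryButtons below resolves all of its labels through react-i18next
// with chatInfo.* keys. The locale resource files are not part of this diff;
// an English bundle covering the keys referenced in the component might look
// like this (only the key names come from the code, the values are guesses):
const enChatInfoResource = {
  chatInfo: {
    title: "Chat Info",
    model: "Model",
    prompt: "Prompt",
    temperature: "Temperature",
    contextLength: "Context Length",
    profileContext: "Profile Context",
    workspaceInstructions: "Workspace Instructions",
    embeddingsProvider: "Embeddings Provider",
    enabled: "Enabled",
    disabled: "Disabled",
    startNewChat: "Start a new chat"
  }
}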
{t("chatInfo.title")}
+ +
+
{t("chatInfo.model")}: {selectedChat.model}
+
{t("chatInfo.prompt")}: {selectedChat.prompt}
+ +
{t("chatInfo.temperature")}: {selectedChat.temperature}
+
{t("chatInfo.contextLength")}: {selectedChat.context_length}
+ +
+ {t("chatInfo.profileContext")}:{" "} + {selectedChat.include_profile_context + ? t("chatInfo.enabled") + : t("chatInfo.disabled")} +
+
+ {" "} + {t("chatInfo.workspaceInstructions")}:{" "} + {selectedChat.include_workspace_instructions + ? t("chatInfo.enabled") + : t("chatInfo.disabled")} +
+ +
+ {t("chatInfo.embeddingsProvider")}: {selectedChat.embeddings_provider} +
+
+ + } + trigger={ +
+ +
+ } + /> + + {t("chatInfo.startNewChat")}} + trigger={ +
+ +
+ } + /> + + )} + + ) +} diff --git a/chatdesk-ui/components/chat/chat-settings.tsx b/chatdesk-ui/components/chat/chat-settings.tsx new file mode 100644 index 0000000..8230d5f --- /dev/null +++ b/chatdesk-ui/components/chat/chat-settings.tsx @@ -0,0 +1,94 @@ +import { ChatbotUIContext } from "@/context/context" +import { CHAT_SETTING_LIMITS } from "@/lib/chat-setting-limits" +import useHotkey from "@/lib/hooks/use-hotkey" +import { LLMID, ModelProvider } from "@/types" +import { IconAdjustmentsHorizontal } from "@tabler/icons-react" +import { FC, useContext, useEffect, useRef } from "react" +import { Button } from "../ui/button" +import { ChatSettingsForm } from "../ui/chat-settings-form" +import { Popover, PopoverContent, PopoverTrigger } from "../ui/popover" + +interface ChatSettingsProps {} + +export const ChatSettings: FC = ({}) => { + useHotkey("i", () => handleClick()) + + const { + chatSettings, + setChatSettings, + models, + availableHostedModels, + availableLocalModels, + availableOpenRouterModels + } = useContext(ChatbotUIContext) + + const buttonRef = useRef(null) + + const handleClick = () => { + if (buttonRef.current) { + buttonRef.current.click() + } + } + + useEffect(() => { + if (!chatSettings) return + + setChatSettings({ + ...chatSettings, + temperature: Math.min( + chatSettings.temperature, + CHAT_SETTING_LIMITS[chatSettings.model]?.MAX_TEMPERATURE || 1 + ), + contextLength: Math.min( + chatSettings.contextLength, + CHAT_SETTING_LIMITS[chatSettings.model]?.MAX_CONTEXT_LENGTH || 4096 + ) + }) + }, [chatSettings?.model]) + + if (!chatSettings) return null + + const allModels = [ + ...models.map(model => ({ + modelId: model.model_id as LLMID, + modelName: model.name, + provider: "custom" as ModelProvider, + hostedId: model.id, + platformLink: "", + imageInput: false + })), + ...availableHostedModels, + ...availableLocalModels, + ...availableOpenRouterModels + ] + + const fullModel = allModels.find(llm => llm.modelId === chatSettings.model) + + return ( + + + + + + + + + + ) +} diff --git a/chatdesk-ui/components/chat/chat-ui.tsx b/chatdesk-ui/components/chat/chat-ui.tsx new file mode 100644 index 0000000..d57bcd0 --- /dev/null +++ b/chatdesk-ui/components/chat/chat-ui.tsx @@ -0,0 +1,234 @@ +import Loading from "@/app/[locale]/loading" +import { useChatHandler } from "@/components/chat/chat-hooks/use-chat-handler" +import { ChatbotUIContext } from "@/context/context" +import { getAssistantToolsByAssistantId } from "@/db/assistant-tools" +import { getChatFilesByChatId } from "@/db/chat-files" +import { getChatById } from "@/db/chats" +import { getMessageFileItemsByMessageId } from "@/db/message-file-items" +import { getMessagesByChatId } from "@/db/messages" +import { getMessageImageFromStorage } from "@/db/storage/message-images" +import { convertBlobToBase64 } from "@/lib/blob-to-b64" +import useHotkey from "@/lib/hooks/use-hotkey" +import { LLMID, MessageImage } from "@/types" +import { useParams } from "next/navigation" +import { FC, useContext, useEffect, useState } from "react" +import { ChatHelp } from "./chat-help" +import { useScroll } from "./chat-hooks/use-scroll" +import { ChatInput } from "./chat-input" +import { ChatMessages } from "./chat-messages" +import { ChatScrollButtons } from "./chat-scroll-buttons" +import { ChatSecondaryButtons } from "./chat-secondary-buttons" + +import { useTranslation } from "react-i18next" // 引入 useTranslation 进行国际化处理 + +interface ChatUIProps {} + +export const ChatUI: FC = ({}) => { + useHotkey("o", () => handleNewChat()) + + 
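// The useEffect in ChatSettings above clamps temperature and context length
// whenever the model changes, using CHAT_SETTING_LIMITS with || fallbacks of 1
// and 4096 for models that have no entry. The same rule as a pure helper; an
// illustrative sketch, the real limits table lives in lib/chat-setting-limits:
interface ModelLimits {
  MAX_TEMPERATURE?: number
  MAX_CONTEXT_LENGTH?: number
}

const clampChatSettings = (
  limits: ModelLimits | undefined,
  temperature: number,
  contextLength: number
) => ({
  temperature: Math.min(temperature, limits?.MAX_TEMPERATURE || 1),
  contextLength: Math.min(contextLength, limits?.MAX_CONTEXT_LENGTH || 4096)
})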
const { t } = useTranslation() // 使用 t 函数进行国际化 + + const params = useParams() + + const { + setChatMessages, + selectedChat, + setSelectedChat, + setChatSettings, + setChatImages, + assistants, + setSelectedAssistant, + setChatFileItems, + setChatFiles, + setShowFilesDisplay, + setUseRetrieval, + setSelectedTools + } = useContext(ChatbotUIContext) + + const { handleNewChat, handleFocusChatInput } = useChatHandler() + + const { + messagesStartRef, + messagesEndRef, + handleScroll, + scrollToBottom, + setIsAtBottom, + isAtTop, + isAtBottom, + isOverflowing, + scrollToTop + } = useScroll() + + const [loading, setLoading] = useState(true) + + useEffect(() => { + const fetchData = async () => { + await fetchMessages() + await fetchChat() + + scrollToBottom() + setIsAtBottom(true) + } + + if (params.chatid) { + fetchData().then(() => { + handleFocusChatInput() + setLoading(false) + }) + } else { + setLoading(false) + } + }, []) + + const fetchMessages = async () => { + const fetchedMessages = await getMessagesByChatId(params.chatid as string) + + const imagePromises: Promise[] = fetchedMessages.flatMap( + message => + message.image_paths + ? message.image_paths.map(async imagePath => { + const url = await getMessageImageFromStorage(imagePath) + + if (url) { + const response = await fetch(url) + const blob = await response.blob() + const base64 = await convertBlobToBase64(blob) + + return { + messageId: message.id, + path: imagePath, + base64, + url, + file: null + } + } + + return { + messageId: message.id, + path: imagePath, + base64: "", + url, + file: null + } + }) + : [] + ) + + const images: MessageImage[] = await Promise.all(imagePromises.flat()) + setChatImages(images) + + const messageFileItemPromises = fetchedMessages.map( + async message => await getMessageFileItemsByMessageId(message.id) + ) + + const messageFileItems = await Promise.all(messageFileItemPromises) + + const uniqueFileItems = messageFileItems.flatMap(item => item.file_items) + setChatFileItems(uniqueFileItems) + + const chatFiles = await getChatFilesByChatId(params.chatid as string) + + setChatFiles( + chatFiles.files.map(file => ({ + id: file.id, + name: file.name, + type: file.type, + file: null + })) + ) + + setUseRetrieval(true) + setShowFilesDisplay(true) + + const fetchedChatMessages = fetchedMessages.map(message => { + return { + message, + fileItems: messageFileItems + .filter(messageFileItem => messageFileItem.id === message.id) + .flatMap(messageFileItem => + messageFileItem.file_items.map(fileItem => fileItem.id) + ) + } + }) + + setChatMessages(fetchedChatMessages) + } + + const fetchChat = async () => { + const chat = await getChatById(params.chatid as string) + if (!chat) return + + if (chat.assistant_id) { + const assistant = assistants.find( + assistant => assistant.id === chat.assistant_id + ) + + if (assistant) { + setSelectedAssistant(assistant) + + const assistantTools = ( + await getAssistantToolsByAssistantId(assistant.id) + ).tools + setSelectedTools(assistantTools) + } + } + + setSelectedChat(chat) + setChatSettings({ + model: chat.model as LLMID, + prompt: chat.prompt, + temperature: chat.temperature, + contextLength: chat.context_length, + includeProfileContext: chat.include_profile_context, + includeWorkspaceInstructions: chat.include_workspace_instructions, + embeddingsProvider: chat.embeddings_provider as "openai" | "local" + }) + } + + if (loading) { + return + } + + return ( +
+
+ +
+ +
+ +
+ +
+
+ {selectedChat?.name || t("chat.defaultChatTitle")} +
+
+ +
+
+ + + +
+
+ +
+ +
+ +
+ +
+
+ ) +} diff --git a/chatdesk-ui/components/chat/file-picker.tsx b/chatdesk-ui/components/chat/file-picker.tsx new file mode 100644 index 0000000..00c4a73 --- /dev/null +++ b/chatdesk-ui/components/chat/file-picker.tsx @@ -0,0 +1,158 @@ +import { ChatbotUIContext } from "@/context/context" +import { Tables } from "@/supabase/types" +import { IconBooks } from "@tabler/icons-react" +import { FC, useContext, useEffect, useRef } from "react" +import { FileIcon } from "../ui/file-icon" + +interface FilePickerProps { + isOpen: boolean + searchQuery: string + onOpenChange: (isOpen: boolean) => void + selectedFileIds: string[] + selectedCollectionIds: string[] + onSelectFile: (file: Tables<"files">) => void + onSelectCollection: (collection: Tables<"collections">) => void + isFocused: boolean +} + +export const FilePicker: FC = ({ + isOpen, + searchQuery, + onOpenChange, + selectedFileIds, + selectedCollectionIds, + onSelectFile, + onSelectCollection, + isFocused +}) => { + const { files, collections, setIsFilePickerOpen } = + useContext(ChatbotUIContext) + + const itemsRef = useRef<(HTMLDivElement | null)[]>([]) + + useEffect(() => { + if (isFocused && itemsRef.current[0]) { + itemsRef.current[0].focus() + } + }, [isFocused]) + + const filteredFiles = files.filter( + file => + file.name.toLowerCase().includes(searchQuery.toLowerCase()) && + !selectedFileIds.includes(file.id) + ) + + const filteredCollections = collections.filter( + collection => + collection.name.toLowerCase().includes(searchQuery.toLowerCase()) && + !selectedCollectionIds.includes(collection.id) + ) + + const handleOpenChange = (isOpen: boolean) => { + onOpenChange(isOpen) + } + + const handleSelectFile = (file: Tables<"files">) => { + onSelectFile(file) + handleOpenChange(false) + } + + const handleSelectCollection = (collection: Tables<"collections">) => { + onSelectCollection(collection) + handleOpenChange(false) + } + + const getKeyDownHandler = + (index: number, type: "file" | "collection", item: any) => + (e: React.KeyboardEvent) => { + if (e.key === "Escape") { + e.preventDefault() + setIsFilePickerOpen(false) + } else if (e.key === "Backspace") { + e.preventDefault() + } else if (e.key === "Enter") { + e.preventDefault() + + if (type === "file") { + handleSelectFile(item) + } else { + handleSelectCollection(item) + } + } else if ( + (e.key === "Tab" || e.key === "ArrowDown") && + !e.shiftKey && + index === filteredFiles.length + filteredCollections.length - 1 + ) { + e.preventDefault() + itemsRef.current[0]?.focus() + } else if (e.key === "ArrowUp" && !e.shiftKey && index === 0) { + // go to last element if arrow up is pressed on first element + e.preventDefault() + itemsRef.current[itemsRef.current.length - 1]?.focus() + } else if (e.key === "ArrowUp") { + e.preventDefault() + const prevIndex = + index - 1 >= 0 ? index - 1 : itemsRef.current.length - 1 + itemsRef.current[prevIndex]?.focus() + } else if (e.key === "ArrowDown") { + e.preventDefault() + const nextIndex = index + 1 < itemsRef.current.length ? index + 1 : 0 + itemsRef.current[nextIndex]?.focus() + } + } + + return ( + <> + {isOpen && ( +
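// fetchMessages above converts each stored image into a base64 data URL with
// convertBlobToBase64 from "@/lib/blob-to-b64", which is not shown in this
// part of the diff. A typical FileReader-based implementation looks like the
// sketch below; this is an assumption about that helper, not its actual source:
export const blobToBase64Sketch = (blob: Blob): Promise<string> =>
  new Promise((resolve, reject) => {
    const reader = new FileReader()
    reader.onloadend = () => resolve(reader.result as string)
    reader.onerror = () => reject(reader.error)
    reader.readAsDataURL(blob)
  })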
+ {filteredFiles.length === 0 && filteredCollections.length === 0 ? ( +
+ No matching files. +
+ ) : ( + <> + {[...filteredFiles, ...filteredCollections].map((item, index) => ( +
{ + itemsRef.current[index] = ref + }} + tabIndex={0} + className="hover:bg-accent focus:bg-accent flex cursor-pointer items-center rounded p-2 focus:outline-none" + onClick={() => { + if ("type" in item) { + handleSelectFile(item as Tables<"files">) + } else { + handleSelectCollection(item) + } + }} + onKeyDown={e => + getKeyDownHandler( + index, + "type" in item ? "file" : "collection", + item + )(e) + } + > + {"type" in item ? ( + ).type} size={32} /> + ) : ( + + )} + +
+
{item.name}
+ +
+ {item.description || "No description."} +
+
+
+ ))} + + )} +
+ )} + + ) +} diff --git a/chatdesk-ui/components/chat/prompt-picker.tsx b/chatdesk-ui/components/chat/prompt-picker.tsx new file mode 100644 index 0000000..55592e3 --- /dev/null +++ b/chatdesk-ui/components/chat/prompt-picker.tsx @@ -0,0 +1,215 @@ +import { ChatbotUIContext } from "@/context/context" +import { Tables } from "@/supabase/types" +import { FC, useContext, useEffect, useRef, useState } from "react" +import { Button } from "../ui/button" +import { Dialog, DialogContent, DialogHeader, DialogTitle } from "../ui/dialog" +import { Label } from "../ui/label" +import { TextareaAutosize } from "../ui/textarea-autosize" +import { usePromptAndCommand } from "./chat-hooks/use-prompt-and-command" + +interface PromptPickerProps {} + +export const PromptPicker: FC = ({}) => { + const { + prompts, + isPromptPickerOpen, + setIsPromptPickerOpen, + focusPrompt, + slashCommand + } = useContext(ChatbotUIContext) + + const { handleSelectPrompt } = usePromptAndCommand() + + const itemsRef = useRef<(HTMLDivElement | null)[]>([]) + + const [promptVariables, setPromptVariables] = useState< + { + promptId: string + name: string + value: string + }[] + >([]) + const [showPromptVariables, setShowPromptVariables] = useState(false) + + useEffect(() => { + if (focusPrompt && itemsRef.current[0]) { + itemsRef.current[0].focus() + } + }, [focusPrompt]) + + const [isTyping, setIsTyping] = useState(false) + + const filteredPrompts = prompts.filter(prompt => + prompt.name.toLowerCase().includes(slashCommand.toLowerCase()) + ) + + const handleOpenChange = (isOpen: boolean) => { + setIsPromptPickerOpen(isOpen) + } + + const callSelectPrompt = (prompt: Tables<"prompts">) => { + const regex = /\{\{.*?\}\}/g + const matches = prompt.content.match(regex) + + if (matches) { + const newPromptVariables = matches.map(match => ({ + promptId: prompt.id, + name: match.replace(/\{\{|\}\}/g, ""), + value: "" + })) + + setPromptVariables(newPromptVariables) + setShowPromptVariables(true) + } else { + handleSelectPrompt(prompt) + handleOpenChange(false) + } + } + + const getKeyDownHandler = + (index: number) => (e: React.KeyboardEvent) => { + if (e.key === "Backspace") { + e.preventDefault() + handleOpenChange(false) + } else if (e.key === "Enter") { + e.preventDefault() + callSelectPrompt(filteredPrompts[index]) + } else if ( + (e.key === "Tab" || e.key === "ArrowDown") && + !e.shiftKey && + index === filteredPrompts.length - 1 + ) { + e.preventDefault() + itemsRef.current[0]?.focus() + } else if (e.key === "ArrowUp" && !e.shiftKey && index === 0) { + // go to last element if arrow up is pressed on first element + e.preventDefault() + itemsRef.current[itemsRef.current.length - 1]?.focus() + } else if (e.key === "ArrowUp") { + e.preventDefault() + const prevIndex = + index - 1 >= 0 ? index - 1 : itemsRef.current.length - 1 + itemsRef.current[prevIndex]?.focus() + } else if (e.key === "ArrowDown") { + e.preventDefault() + const nextIndex = index + 1 < itemsRef.current.length ? 
index + 1 : 0 + itemsRef.current[nextIndex]?.focus() + } + } + + const handleSubmitPromptVariables = () => { + const newPromptContent = promptVariables.reduce( + (prevContent, variable) => + prevContent.replace( + new RegExp(`\\{\\{${variable.name}\\}\\}`, "g"), + variable.value + ), + prompts.find(prompt => prompt.id === promptVariables[0].promptId) + ?.content || "" + ) + + const newPrompt: any = { + ...prompts.find(prompt => prompt.id === promptVariables[0].promptId), + content: newPromptContent + } + + handleSelectPrompt(newPrompt) + handleOpenChange(false) + setShowPromptVariables(false) + setPromptVariables([]) + } + + const handleCancelPromptVariables = () => { + setShowPromptVariables(false) + setPromptVariables([]) + } + + const handleKeydownPromptVariables = ( + e: React.KeyboardEvent + ) => { + if (!isTyping && e.key === "Enter" && !e.shiftKey) { + e.preventDefault() + handleSubmitPromptVariables() + } + } + + return ( + <> + {isPromptPickerOpen && ( +
+ {showPromptVariables ? ( + + + + Enter Prompt Variables + + +
+ {promptVariables.map((variable, index) => ( +
+ + + { + const newPromptVariables = [...promptVariables] + newPromptVariables[index].value = value + setPromptVariables(newPromptVariables) + }} + minRows={3} + maxRows={5} + onCompositionStart={() => setIsTyping(true)} + onCompositionEnd={() => setIsTyping(false)} + /> +
+ ))} +
+ +
+ + + +
+
+
+ ) : filteredPrompts.length === 0 ? ( +
+ No matching prompts. +
+ ) : ( + filteredPrompts.map((prompt, index) => ( +
{ + itemsRef.current[index] = ref + }} + tabIndex={0} + className="hover:bg-accent focus:bg-accent flex cursor-pointer flex-col rounded p-2 focus:outline-none" + onClick={() => callSelectPrompt(prompt)} + onKeyDown={getKeyDownHandler(index)} + > +
{prompt.name}
+ +
+ {prompt.content} +
+
+ )) + )} +
+ )} + + ) +} diff --git a/chatdesk-ui/components/chat/quick-setting-option.tsx b/chatdesk-ui/components/chat/quick-setting-option.tsx new file mode 100644 index 0000000..6ddd48b --- /dev/null +++ b/chatdesk-ui/components/chat/quick-setting-option.tsx @@ -0,0 +1,71 @@ +import { LLM_LIST } from "@/lib/models/llm/llm-list" +import { Tables } from "@/supabase/types" +import { IconCircleCheckFilled, IconRobotFace } from "@tabler/icons-react" +import Image from "next/image" +import { FC } from "react" +import { ModelIcon } from "../models/model-icon" +import { DropdownMenuItem } from "../ui/dropdown-menu" + +interface QuickSettingOptionProps { + contentType: "presets" | "assistants" + isSelected: boolean + item: Tables<"presets"> | Tables<"assistants"> + onSelect: () => void + image: string +} + +export const QuickSettingOption: FC = ({ + contentType, + isSelected, + item, + onSelect, + image +}) => { + const modelDetails = LLM_LIST.find(model => model.modelId === item.model) + + return ( + +
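// PromptPicker above pulls {{placeholders}} out of a prompt with
// /\{\{.*?\}\}/g and later substitutes the collected values back using a
// per-variable RegExp. The same round trip as two pure functions, mirroring
// the component's regexes (an illustrative sketch, not shared code in this repo):
const extractPromptVariables = (content: string): string[] =>
  (content.match(/\{\{.*?\}\}/g) ?? []).map(m => m.replace(/\{\{|\}\}/g, ""))

const fillPromptVariables = (
  content: string,
  values: Record<string, string>
): string =>
  Object.entries(values).reduce(
    (filled, [name, value]) =>
      filled.replace(new RegExp(`\\{\\{${name}\\}\\}`, "g"), value),
    content
  )

// extractPromptVariables("Translate {{text}} into {{language}}")
//   => ["text", "language"]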
+ {contentType === "presets" ? ( + + ) : image ? ( + Assistant + ) : ( + + )} +
+ +
+
{item.name}
+ + {item.description && ( +
{item.description}
+ )} +
+ +
+ {isSelected ? ( + + ) : null} +
+
+ ) +} diff --git a/chatdesk-ui/components/chat/quick-settings.tsx b/chatdesk-ui/components/chat/quick-settings.tsx new file mode 100644 index 0000000..9a3f178 --- /dev/null +++ b/chatdesk-ui/components/chat/quick-settings.tsx @@ -0,0 +1,308 @@ +import { ChatbotUIContext } from "@/context/context" +import { getAssistantCollectionsByAssistantId } from "@/db/assistant-collections" +import { getAssistantFilesByAssistantId } from "@/db/assistant-files" +import { getAssistantToolsByAssistantId } from "@/db/assistant-tools" +import { getCollectionFilesByCollectionId } from "@/db/collection-files" +import useHotkey from "@/lib/hooks/use-hotkey" +import { LLM_LIST } from "@/lib/models/llm/llm-list" +import { Tables } from "@/supabase/types" +import { LLMID } from "@/types" +import { IconChevronDown, IconRobotFace } from "@tabler/icons-react" +import Image from "next/image" +import { FC, useContext, useEffect, useRef, useState } from "react" +import { useTranslation } from "react-i18next" +import { ModelIcon } from "../models/model-icon" +import { Button } from "../ui/button" +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuTrigger +} from "../ui/dropdown-menu" +import { Input } from "../ui/input" +import { QuickSettingOption } from "./quick-setting-option" +import { set } from "date-fns" + +interface QuickSettingsProps {} + +export const QuickSettings: FC = ({}) => { + const { t } = useTranslation() + + useHotkey("p", () => setIsOpen(prevState => !prevState)) + + const { + presets, + assistants, + selectedAssistant, + selectedPreset, + chatSettings, + setSelectedPreset, + setSelectedAssistant, + setChatSettings, + assistantImages, + setChatFiles, + setSelectedTools, + setShowFilesDisplay, + selectedWorkspace + } = useContext(ChatbotUIContext) + + const inputRef = useRef(null) + + const [isOpen, setIsOpen] = useState(false) + const [search, setSearch] = useState("") + const [loading, setLoading] = useState(false) + + useEffect(() => { + if (isOpen) { + setTimeout(() => { + inputRef.current?.focus() + }, 100) // FIX: hacky + } + }, [isOpen]) + + const handleSelectQuickSetting = async ( + item: Tables<"presets"> | Tables<"assistants"> | null, + contentType: "presets" | "assistants" | "remove" + ) => { + console.log({ item, contentType }) + if (contentType === "assistants" && item) { + setSelectedAssistant(item as Tables<"assistants">) + setLoading(true) + let allFiles = [] + const assistantFiles = (await getAssistantFilesByAssistantId(item.id)) + .files + allFiles = [...assistantFiles] + const assistantCollections = ( + await getAssistantCollectionsByAssistantId(item.id) + ).collections + for (const collection of assistantCollections) { + const collectionFiles = ( + await getCollectionFilesByCollectionId(collection.id) + ).files + allFiles = [...allFiles, ...collectionFiles] + } + const assistantTools = (await getAssistantToolsByAssistantId(item.id)) + .tools + setSelectedTools(assistantTools) + setChatFiles( + allFiles.map(file => ({ + id: file.id, + name: file.name, + type: file.type, + file: null + })) + ) + if (allFiles.length > 0) setShowFilesDisplay(true) + setLoading(false) + setSelectedPreset(null) + } else if (contentType === "presets" && item) { + setSelectedPreset(item as Tables<"presets">) + setSelectedAssistant(null) + setChatFiles([]) + setSelectedTools([]) + } else { + setSelectedPreset(null) + setSelectedAssistant(null) + setChatFiles([]) + setSelectedTools([]) + if (selectedWorkspace) { + setChatSettings({ + model: selectedWorkspace.default_model as LLMID, + prompt: 
selectedWorkspace.default_prompt, + temperature: selectedWorkspace.default_temperature, + contextLength: selectedWorkspace.default_context_length, + includeProfileContext: selectedWorkspace.include_profile_context, + includeWorkspaceInstructions: + selectedWorkspace.include_workspace_instructions, + embeddingsProvider: selectedWorkspace.embeddings_provider as + | "openai" + | "local" + }) + } + return + } + + setChatSettings({ + model: item.model as LLMID, + prompt: item.prompt, + temperature: item.temperature, + contextLength: item.context_length, + includeProfileContext: item.include_profile_context, + includeWorkspaceInstructions: item.include_workspace_instructions, + embeddingsProvider: item.embeddings_provider as "openai" | "local" + }) + } + + const checkIfModified = () => { + if (!chatSettings) return false + + if (selectedPreset) { + return ( + selectedPreset.include_profile_context !== + chatSettings?.includeProfileContext || + selectedPreset.include_workspace_instructions !== + chatSettings.includeWorkspaceInstructions || + selectedPreset.context_length !== chatSettings.contextLength || + selectedPreset.model !== chatSettings.model || + selectedPreset.prompt !== chatSettings.prompt || + selectedPreset.temperature !== chatSettings.temperature + ) + } else if (selectedAssistant) { + return ( + selectedAssistant.include_profile_context !== + chatSettings.includeProfileContext || + selectedAssistant.include_workspace_instructions !== + chatSettings.includeWorkspaceInstructions || + selectedAssistant.context_length !== chatSettings.contextLength || + selectedAssistant.model !== chatSettings.model || + selectedAssistant.prompt !== chatSettings.prompt || + selectedAssistant.temperature !== chatSettings.temperature + ) + } + + return false + } + + const isModified = checkIfModified() + + const items = [ + ...presets.map(preset => ({ ...preset, contentType: "presets" })), + ...assistants.map(assistant => ({ + ...assistant, + contentType: "assistants" + })) + ] + + const selectedAssistantImage = selectedPreset + ? "" + : assistantImages.find( + image => image.path === selectedAssistant?.image_path + )?.base64 || "" + + const modelDetails = LLM_LIST.find( + model => model.modelId === selectedPreset?.model + ) + + return ( + { + setIsOpen(isOpen) + setSearch("") + }} + > + + + + + + {presets.length === 0 && assistants.length === 0 ? ( +
{t("chat.noItemsFound")}
+ ) : ( + <> + setSearch(e.target.value)} + onKeyDown={e => e.stopPropagation()} + /> + + {!!(selectedPreset || selectedAssistant) && ( + + | Tables<"assistants">) + } + onSelect={() => { + handleSelectQuickSetting(null, "remove") + }} + image={selectedPreset ? "" : selectedAssistantImage} + /> + )} + + {items + .filter( + item => + item.name.toLowerCase().includes(search.toLowerCase()) && + item.id !== selectedPreset?.id && + item.id !== selectedAssistant?.id + ) + .map(({ contentType, ...item }) => ( + + handleSelectQuickSetting( + item, + contentType as "presets" | "assistants" + ) + } + image={ + contentType === "assistants" + ? assistantImages.find( + image => + image.path === + (item as Tables<"assistants">).image_path + )?.base64 || "" + : "" + } + /> + ))} + + )} +
+
+ ) +} diff --git a/chatdesk-ui/components/chat/tool-picker.tsx b/chatdesk-ui/components/chat/tool-picker.tsx new file mode 100644 index 0000000..1088737 --- /dev/null +++ b/chatdesk-ui/components/chat/tool-picker.tsx @@ -0,0 +1,110 @@ +import { ChatbotUIContext } from "@/context/context" +import { Tables } from "@/supabase/types" +import { IconBolt } from "@tabler/icons-react" +import { FC, useContext, useEffect, useRef } from "react" +import { usePromptAndCommand } from "./chat-hooks/use-prompt-and-command" + +interface ToolPickerProps {} + +export const ToolPicker: FC = ({}) => { + const { + tools, + focusTool, + toolCommand, + isToolPickerOpen, + setIsToolPickerOpen + } = useContext(ChatbotUIContext) + + const { handleSelectTool } = usePromptAndCommand() + + const itemsRef = useRef<(HTMLDivElement | null)[]>([]) + + useEffect(() => { + if (focusTool && itemsRef.current[0]) { + itemsRef.current[0].focus() + } + }, [focusTool]) + + const filteredTools = tools.filter(tool => + tool.name.toLowerCase().includes(toolCommand.toLowerCase()) + ) + + const handleOpenChange = (isOpen: boolean) => { + setIsToolPickerOpen(isOpen) + } + + const callSelectTool = (tool: Tables<"tools">) => { + handleSelectTool(tool) + handleOpenChange(false) + } + + const getKeyDownHandler = + (index: number) => (e: React.KeyboardEvent) => { + if (e.key === "Backspace") { + e.preventDefault() + handleOpenChange(false) + } else if (e.key === "Enter") { + e.preventDefault() + callSelectTool(filteredTools[index]) + } else if ( + (e.key === "Tab" || e.key === "ArrowDown") && + !e.shiftKey && + index === filteredTools.length - 1 + ) { + e.preventDefault() + itemsRef.current[0]?.focus() + } else if (e.key === "ArrowUp" && !e.shiftKey && index === 0) { + // go to last element if arrow up is pressed on first element + e.preventDefault() + itemsRef.current[itemsRef.current.length - 1]?.focus() + } else if (e.key === "ArrowUp") { + e.preventDefault() + const prevIndex = + index - 1 >= 0 ? index - 1 : itemsRef.current.length - 1 + itemsRef.current[prevIndex]?.focus() + } else if (e.key === "ArrowDown") { + e.preventDefault() + const nextIndex = index + 1 < itemsRef.current.length ? index + 1 : 0 + itemsRef.current[nextIndex]?.focus() + } + } + + return ( + <> + {isToolPickerOpen && ( +
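// ToolPicker above repeats the same wrap-around Tab/Arrow focus logic that
// FilePicker and PromptPicker implement over itemsRef. The index arithmetic
// could be shared; a sketch of such a helper (not an existing utility here):
const nextFocusIndex = (
  current: number,
  count: number,
  key: "ArrowUp" | "ArrowDown" | "Tab"
): number => {
  if (count === 0) return -1
  return key === "ArrowUp"
    ? (current - 1 + count) % count // wrap from the first item to the last
    : (current + 1) % count // Tab / ArrowDown wrap from the last item to the first
}

// nextFocusIndex(0, 5, "ArrowUp")   => 4
// nextFocusIndex(4, 5, "ArrowDown") => 0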
+ {filteredTools.length === 0 ? ( +
+ No matching tools. +
+ ) : ( + <> + {filteredTools.map((item, index) => ( +
{ + itemsRef.current[index] = ref + }} + tabIndex={0} + className="hover:bg-accent focus:bg-accent flex cursor-pointer items-center rounded p-2 focus:outline-none" + onClick={() => callSelectTool(item as Tables<"tools">)} + onKeyDown={getKeyDownHandler(index)} + > + + +
+
{item.name}
+ +
+ {item.description || "No description."} +
+
+
+ ))} + + )} +
+ )} + + ) +} diff --git a/chatdesk-ui/components/icons/anthropic-svg.tsx b/chatdesk-ui/components/icons/anthropic-svg.tsx new file mode 100644 index 0000000..27b2cd1 --- /dev/null +++ b/chatdesk-ui/components/icons/anthropic-svg.tsx @@ -0,0 +1,44 @@ +import { FC } from "react" + +interface AnthropicSVGProps { + height?: number + width?: number + className?: string +} + +export const AnthropicSVG: FC = ({ + height = 40, + width = 40, + className +}) => { + return ( + + + + + + + + + ) +} diff --git a/chatdesk-ui/components/icons/chatbotui-svg.tsx b/chatdesk-ui/components/icons/chatbotui-svg.tsx new file mode 100644 index 0000000..4c29cf6 --- /dev/null +++ b/chatdesk-ui/components/icons/chatbotui-svg.tsx @@ -0,0 +1,37 @@ +import { FC } from "react" + +interface ChatbotUISVGProps { + theme: "dark" | "light" + scale?: number +} + +export const ChatbotUISVG: FC = ({ theme, scale = 1 }) => { + return ( + + + + + + ) +} diff --git a/chatdesk-ui/components/icons/google-svg.tsx b/chatdesk-ui/components/icons/google-svg.tsx new file mode 100644 index 0000000..8a86709 --- /dev/null +++ b/chatdesk-ui/components/icons/google-svg.tsx @@ -0,0 +1,42 @@ +import { FC } from "react" + +interface GoogleSVGProps { + height?: number + width?: number + className?: string +} + +export const GoogleSVG: FC = ({ + height = 40, + width = 40, + className +}) => { + return ( + + + + + + + ) +} diff --git a/chatdesk-ui/components/icons/openai-svg.tsx b/chatdesk-ui/components/icons/openai-svg.tsx new file mode 100644 index 0000000..670c613 --- /dev/null +++ b/chatdesk-ui/components/icons/openai-svg.tsx @@ -0,0 +1,31 @@ +import { FC } from "react" + +interface OpenAISVGProps { + height?: number + width?: number + className?: string +} + +export const OpenAISVG: FC = ({ + height = 40, + width = 40, + className +}) => { + return ( + + + + ) +} diff --git a/chatdesk-ui/components/messages/message-actions.tsx b/chatdesk-ui/components/messages/message-actions.tsx new file mode 100644 index 0000000..0e1c8c7 --- /dev/null +++ b/chatdesk-ui/components/messages/message-actions.tsx @@ -0,0 +1,117 @@ +import { ChatbotUIContext } from "@/context/context" +import { IconCheck, IconCopy, IconEdit, IconRepeat } from "@tabler/icons-react" +import { FC, useContext, useEffect, useState } from "react" +import { WithTooltip } from "../ui/with-tooltip" + +export const MESSAGE_ICON_SIZE = 18 + +interface MessageActionsProps { + isAssistant: boolean + isLast: boolean + isEditing: boolean + isHovering: boolean + onCopy: () => void + onEdit: () => void + onRegenerate: () => void +} + +export const MessageActions: FC = ({ + isAssistant, + isLast, + isEditing, + isHovering, + onCopy, + onEdit, + onRegenerate +}) => { + const { isGenerating } = useContext(ChatbotUIContext) + + const [showCheckmark, setShowCheckmark] = useState(false) + + const handleCopy = () => { + onCopy() + setShowCheckmark(true) + } + + const handleForkChat = async () => {} + + useEffect(() => { + if (showCheckmark) { + const timer = setTimeout(() => { + setShowCheckmark(false) + }, 2000) + + return () => clearTimeout(timer) + } + }, [showCheckmark]) + + return (isLast && isGenerating) || isEditing ? null : ( +
+ {/* {((isAssistant && isHovering) || isLast) && ( + Fork Chat
} + trigger={ + + } + /> + )} */} + + {!isAssistant && isHovering && ( + Edit
} + trigger={ + + } + /> + )} + + {(isHovering || isLast) && ( + Copy
} + trigger={ + showCheckmark ? ( + + ) : ( + + ) + } + /> + )} + + {isLast && ( + Regenerate} + trigger={ + + } + /> + )} + + {/* {1 > 0 && isAssistant && } */} + + ) +} diff --git a/chatdesk-ui/components/messages/message-codeblock.tsx b/chatdesk-ui/components/messages/message-codeblock.tsx new file mode 100644 index 0000000..2b8d795 --- /dev/null +++ b/chatdesk-ui/components/messages/message-codeblock.tsx @@ -0,0 +1,135 @@ +import { Button } from "@/components/ui/button" +import { useCopyToClipboard } from "@/lib/hooks/use-copy-to-clipboard" +import { IconCheck, IconCopy, IconDownload } from "@tabler/icons-react" +import { FC, memo } from "react" +import { Prism as SyntaxHighlighter } from "react-syntax-highlighter" +import { oneDark } from "react-syntax-highlighter/dist/cjs/styles/prism" + +interface MessageCodeBlockProps { + language: string + value: string +} + +interface languageMap { + [key: string]: string | undefined +} + +export const programmingLanguages: languageMap = { + javascript: ".js", + python: ".py", + java: ".java", + c: ".c", + cpp: ".cpp", + "c++": ".cpp", + "c#": ".cs", + ruby: ".rb", + php: ".php", + swift: ".swift", + "objective-c": ".m", + kotlin: ".kt", + typescript: ".ts", + go: ".go", + perl: ".pl", + rust: ".rs", + scala: ".scala", + haskell: ".hs", + lua: ".lua", + shell: ".sh", + sql: ".sql", + html: ".html", + css: ".css" +} + +export const generateRandomString = (length: number, lowercase = false) => { + const chars = "ABCDEFGHJKLMNPQRSTUVWXY3456789" // excluding similar looking characters like Z, 2, I, 1, O, 0 + let result = "" + for (let i = 0; i < length; i++) { + result += chars.charAt(Math.floor(Math.random() * chars.length)) + } + return lowercase ? result.toLowerCase() : result +} + +export const MessageCodeBlock: FC = memo( + ({ language, value }) => { + const { isCopied, copyToClipboard } = useCopyToClipboard({ timeout: 2000 }) + + const downloadAsFile = () => { + if (typeof window === "undefined") { + return + } + const fileExtension = programmingLanguages[language] || ".file" + const suggestedFileName = `file-${generateRandomString( + 3, + true + )}${fileExtension}` + const fileName = window.prompt("Enter file name" || "", suggestedFileName) + + if (!fileName) { + return + } + + const blob = new Blob([value], { type: "text/plain" }) + const url = URL.createObjectURL(blob) + const link = document.createElement("a") + link.download = fileName + link.href = url + link.style.display = "none" + document.body.appendChild(link) + link.click() + document.body.removeChild(link) + URL.revokeObjectURL(url) + } + + const onCopy = () => { + if (isCopied) return + copyToClipboard(value) + } + + return ( +
+
+ {language} +
+ + + +
+
+ + {value} + +
+ ) + } +) + +MessageCodeBlock.displayName = "MessageCodeBlock" diff --git a/chatdesk-ui/components/messages/message-markdown-memoized.tsx b/chatdesk-ui/components/messages/message-markdown-memoized.tsx new file mode 100644 index 0000000..2fc2106 --- /dev/null +++ b/chatdesk-ui/components/messages/message-markdown-memoized.tsx @@ -0,0 +1,9 @@ +import { FC, memo } from "react" +import ReactMarkdown, { Options } from "react-markdown" + +export const MessageMarkdownMemoized: FC = memo( + ReactMarkdown, + (prevProps, nextProps) => + prevProps.children === nextProps.children && + prevProps.className === nextProps.className +) diff --git a/chatdesk-ui/components/messages/message-markdown.tsx b/chatdesk-ui/components/messages/message-markdown.tsx new file mode 100644 index 0000000..88be7e9 --- /dev/null +++ b/chatdesk-ui/components/messages/message-markdown.tsx @@ -0,0 +1,65 @@ +import React, { FC } from "react" +import remarkGfm from "remark-gfm" +import remarkMath from "remark-math" +import { MessageCodeBlock } from "./message-codeblock" +import { MessageMarkdownMemoized } from "./message-markdown-memoized" + +interface MessageMarkdownProps { + content: string +} + +export const MessageMarkdown: FC = ({ content }) => { + return ( + {children}

+ }, + img({ node, ...props }) { + return + }, + code({ node, className, children, ...props }) { + const childArray = React.Children.toArray(children) + const firstChild = childArray[0] as React.ReactElement + const firstChildAsString = React.isValidElement(firstChild) + ? (firstChild as React.ReactElement).props.children + : firstChild + + if (firstChildAsString === "▍") { + return + } + + if (typeof firstChildAsString === "string") { + childArray[0] = firstChildAsString.replace("`▍`", "▍") + } + + const match = /language-(\w+)/.exec(className || "") + + if ( + typeof firstChildAsString === "string" && + !firstChildAsString.includes("\n") + ) { + return ( + + {childArray} + + ) + } + + return ( + + ) + } + }} + > + {content} +
+ ) +} diff --git a/chatdesk-ui/components/messages/message-replies.tsx b/chatdesk-ui/components/messages/message-replies.tsx new file mode 100644 index 0000000..e9dd75b --- /dev/null +++ b/chatdesk-ui/components/messages/message-replies.tsx @@ -0,0 +1,51 @@ +import { IconMessage } from "@tabler/icons-react" +import { FC, useState } from "react" +import { + Sheet, + SheetContent, + SheetDescription, + SheetHeader, + SheetTitle, + SheetTrigger +} from "../ui/sheet" +import { WithTooltip } from "../ui/with-tooltip" +import { MESSAGE_ICON_SIZE } from "./message-actions" + +interface MessageRepliesProps {} + +export const MessageReplies: FC = ({}) => { + const [isOpen, setIsOpen] = useState(false) + + return ( + + + View Replies} + trigger={ +
setIsOpen(true)} + > + +
+ {1} +
+
+ } + /> +
+ + + + Are you sure absolutely sure? + + This action cannot be undone. This will permanently delete your + account and remove your data from our servers. + + + +
+ ) +} diff --git a/chatdesk-ui/components/messages/message.tsx b/chatdesk-ui/components/messages/message.tsx new file mode 100644 index 0000000..d0867d6 --- /dev/null +++ b/chatdesk-ui/components/messages/message.tsx @@ -0,0 +1,445 @@ +import { useChatHandler } from "@/components/chat/chat-hooks/use-chat-handler" +import { ChatbotUIContext } from "@/context/context" +import { LLM_LIST } from "@/lib/models/llm/llm-list" +import { cn } from "@/lib/utils" +import { Tables } from "@/supabase/types" +import { LLM, LLMID, MessageImage, ModelProvider } from "@/types" +import { + IconBolt, + IconCaretDownFilled, + IconCaretRightFilled, + IconCircleFilled, + IconFileText, + IconMoodSmile, + IconPencil +} from "@tabler/icons-react" +import Image from "next/image" +import { FC, useContext, useEffect, useRef, useState } from "react" +import { ModelIcon } from "../models/model-icon" +import { Button } from "../ui/button" +import { FileIcon } from "../ui/file-icon" +import { FilePreview } from "../ui/file-preview" +import { TextareaAutosize } from "../ui/textarea-autosize" +import { WithTooltip } from "../ui/with-tooltip" +import { MessageActions } from "./message-actions" +import { MessageMarkdown } from "./message-markdown" + +const ICON_SIZE = 32 + +interface MessageProps { + message: Tables<"messages"> + fileItems: Tables<"file_items">[] + isEditing: boolean + isLast: boolean + onStartEdit: (message: Tables<"messages">) => void + onCancelEdit: () => void + onSubmitEdit: (value: string, sequenceNumber: number) => void +} + +export const Message: FC = ({ + message, + fileItems, + isEditing, + isLast, + onStartEdit, + onCancelEdit, + onSubmitEdit +}) => { + const { + assistants, + profile, + isGenerating, + setIsGenerating, + firstTokenReceived, + availableLocalModels, + availableOpenRouterModels, + chatMessages, + selectedAssistant, + chatImages, + assistantImages, + toolInUse, + files, + models + } = useContext(ChatbotUIContext) + + const { handleSendMessage } = useChatHandler() + + const editInputRef = useRef(null) + + const [isHovering, setIsHovering] = useState(false) + const [editedMessage, setEditedMessage] = useState(message.content) + + const [showImagePreview, setShowImagePreview] = useState(false) + const [selectedImage, setSelectedImage] = useState(null) + + const [showFileItemPreview, setShowFileItemPreview] = useState(false) + const [selectedFileItem, setSelectedFileItem] = + useState | null>(null) + + const [viewSources, setViewSources] = useState(false) + + const handleCopy = () => { + if (navigator.clipboard) { + navigator.clipboard.writeText(message.content) + } else { + const textArea = document.createElement("textarea") + textArea.value = message.content + document.body.appendChild(textArea) + textArea.focus() + textArea.select() + document.execCommand("copy") + document.body.removeChild(textArea) + } + } + + const handleSendEdit = () => { + onSubmitEdit(editedMessage, message.sequence_number) + onCancelEdit() + } + + const handleKeyDown = (event: React.KeyboardEvent) => { + if (isEditing && event.key === "Enter" && event.metaKey) { + handleSendEdit() + } + } + + const handleRegenerate = async () => { + setIsGenerating(true) + await handleSendMessage( + editedMessage || chatMessages[chatMessages.length - 2].message.content, + chatMessages, + true + ) + } + + const handleStartEdit = () => { + onStartEdit(message) + } + + useEffect(() => { + setEditedMessage(message.content) + + if (isEditing && editInputRef.current) { + const input = editInputRef.current + input.focus() + 
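// handleCopy above falls back to document.execCommand("copy") with a temporary
// textarea when navigator.clipboard is unavailable (for example on plain HTTP).
// execCommand is deprecated but still the usual last resort; an async variant
// that reports success instead of failing silently might look like this sketch:
const copyText = async (text: string): Promise<boolean> => {
  if (navigator.clipboard?.writeText) {
    try {
      await navigator.clipboard.writeText(text)
      return true
    } catch {
      // fall through to the legacy path below
    }
  }
  const textArea = document.createElement("textarea")
  textArea.value = text
  document.body.appendChild(textArea)
  textArea.focus()
  textArea.select()
  const ok = document.execCommand("copy")
  document.body.removeChild(textArea)
  return ok
}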
input.setSelectionRange(input.value.length, input.value.length) + } + }, [isEditing]) + + const MODEL_DATA = [ + ...models.map(model => ({ + modelId: model.model_id as LLMID, + modelName: model.name, + provider: "custom" as ModelProvider, + hostedId: model.id, + platformLink: "", + imageInput: false + })), + ...LLM_LIST, + ...availableLocalModels, + ...availableOpenRouterModels + ].find(llm => llm.modelId === message.model) as LLM + + const messageAssistantImage = assistantImages.find( + image => image.assistantId === message.assistant_id + )?.base64 + + const selectedAssistantImage = assistantImages.find( + image => image.path === selectedAssistant?.image_path + )?.base64 + + const modelDetails = LLM_LIST.find(model => model.modelId === message.model) + + const fileAccumulator: Record< + string, + { + id: string + name: string + count: number + type: string + description: string + } + > = {} + + const fileSummary = fileItems.reduce((acc, fileItem) => { + const parentFile = files.find(file => file.id === fileItem.file_id) + if (parentFile) { + if (!acc[parentFile.id]) { + acc[parentFile.id] = { + id: parentFile.id, + name: parentFile.name, + count: 1, + type: parentFile.type, + description: parentFile.description + } + } else { + acc[parentFile.id].count += 1 + } + } + return acc + }, fileAccumulator) + + return ( +
setIsHovering(true)} + onMouseLeave={() => setIsHovering(false)} + onKeyDown={handleKeyDown} + > +
+
+ +
+
+ {message.role === "system" ? ( +
+ + +
Prompt
+
+ ) : ( +
+ {message.role === "assistant" ? ( + messageAssistantImage ? ( + assistant image + ) : ( + {MODEL_DATA?.modelName}
} + trigger={ + + } + /> + ) + ) : profile?.image_url ? ( + user image + ) : ( + + )} + +
+ {message.role === "assistant" + ? message.assistant_id + ? assistants.find( + assistant => assistant.id === message.assistant_id + )?.name + : selectedAssistant + ? selectedAssistant?.name + : MODEL_DATA?.modelName + : profile?.display_name ?? profile?.username} +
+
+ )} + {!firstTokenReceived && + isGenerating && + isLast && + message.role === "assistant" ? ( + <> + {(() => { + switch (toolInUse) { + case "none": + return ( + + ) + case "retrieval": + return ( +
+ + +
Searching files...
+
+ ) + default: + return ( +
+ + +
Using {toolInUse}...
+
+ ) + } + })()} + + ) : isEditing ? ( + + ) : ( + + )} +
+ + {fileItems.length > 0 && ( +
+ {!viewSources ? ( +
setViewSources(true)} + > + {fileItems.length} + {fileItems.length > 1 ? " Sources " : " Source "} + from {Object.keys(fileSummary).length}{" "} + {Object.keys(fileSummary).length > 1 ? "Files" : "File"}{" "} + +
+ ) : ( + <> +
setViewSources(false)} + > + {fileItems.length} + {fileItems.length > 1 ? " Sources " : " Source "} + from {Object.keys(fileSummary).length}{" "} + {Object.keys(fileSummary).length > 1 ? "Files" : "File"}{" "} + +
+ +
+ {Object.values(fileSummary).map((file, index) => ( +
+
+
+ +
+ +
{file.name}
+
+ + {fileItems + .filter(fileItem => { + const parentFile = files.find( + parentFile => parentFile.id === fileItem.file_id + ) + return parentFile?.id === file.id + }) + .map((fileItem, index) => ( +
{ + setSelectedFileItem(fileItem) + setShowFileItemPreview(true) + }} + > +
+ -{" "} + {fileItem.content.substring(0, 200)}... +
+
+ ))} +
+ ))} +
+ + )} +
+ )} + +
+ {message.image_paths.map((path, index) => { + const item = chatImages.find(image => image.path === path) + + return ( + message image { + setSelectedImage({ + messageId: message.id, + path, + base64: path.startsWith("data") ? path : item?.base64 || "", + url: path.startsWith("data") ? "" : item?.url || "", + file: null + }) + + setShowImagePreview(true) + }} + loading="lazy" + /> + ) + })} +
+ {isEditing && ( +
+ + + +
+ )} +
+ + {showImagePreview && selectedImage && ( + { + setShowImagePreview(isOpen) + setSelectedImage(null) + }} + /> + )} + + {showFileItemPreview && selectedFileItem && ( + { + setShowFileItemPreview(isOpen) + setSelectedFileItem(null) + }} + /> + )} + + ) +} diff --git a/chatdesk-ui/components/models/model-icon.tsx b/chatdesk-ui/components/models/model-icon.tsx new file mode 100644 index 0000000..27ca7b4 --- /dev/null +++ b/chatdesk-ui/components/models/model-icon.tsx @@ -0,0 +1,107 @@ +import { cn } from "@/lib/utils" +import mistral from "@/public/providers/mistral.png" +import groq from "@/public/providers/groq.png" +import perplexity from "@/public/providers/perplexity.png" +import { ModelProvider } from "@/types" +import { IconSparkles } from "@tabler/icons-react" +import { useTheme } from "next-themes" +import Image from "next/image" +import { FC, HTMLAttributes } from "react" +import { AnthropicSVG } from "../icons/anthropic-svg" +import { GoogleSVG } from "../icons/google-svg" +import { OpenAISVG } from "../icons/openai-svg" + +interface ModelIconProps extends HTMLAttributes { + provider: ModelProvider + height: number + width: number +} + +export const ModelIcon: FC = ({ + provider, + height, + width, + ...props +}) => { + const { theme } = useTheme() + + switch (provider as ModelProvider) { + case "openai": + return ( + + ) + case "mistral": + return ( + Mistral + ) + case "groq": + return ( + Groq + ) + case "anthropic": + return ( + + ) + case "google": + return ( + + ) + case "perplexity": + return ( + Mistral + ) + default: + return + } +} diff --git a/chatdesk-ui/components/models/model-option.tsx b/chatdesk-ui/components/models/model-option.tsx new file mode 100644 index 0000000..2344d3d --- /dev/null +++ b/chatdesk-ui/components/models/model-option.tsx @@ -0,0 +1,49 @@ +import { LLM } from "@/types" +import { FC } from "react" +import { ModelIcon } from "./model-icon" +import { IconInfoCircle } from "@tabler/icons-react" +import { WithTooltip } from "../ui/with-tooltip" + +interface ModelOptionProps { + model: LLM + onSelect: () => void +} + +export const ModelOption: FC = ({ model, onSelect }) => { + return ( + + {model.provider !== "ollama" && model.pricing && ( +
+
+ Input Cost:{" "} + {model.pricing.inputCost} {model.pricing.currency} per{" "} + {model.pricing.unit} +
+ {model.pricing.outputCost && ( +
+ Output Cost:{" "} + {model.pricing.outputCost} {model.pricing.currency} per{" "} + {model.pricing.unit} +
+ )} +
+ )} + + } + side="bottom" + trigger={ +
+
+ +
{model.modelName}
+
+
+ } + /> + ) +} diff --git a/chatdesk-ui/components/models/model-select.tsx b/chatdesk-ui/components/models/model-select.tsx new file mode 100644 index 0000000..1b7d271 --- /dev/null +++ b/chatdesk-ui/components/models/model-select.tsx @@ -0,0 +1,210 @@ +import { ChatbotUIContext } from "@/context/context" +import { LLM, LLMID, ModelProvider } from "@/types" +import { IconCheck, IconChevronDown } from "@tabler/icons-react" +import { FC, useContext, useEffect, useRef, useState } from "react" +import { Button } from "../ui/button" +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuTrigger +} from "../ui/dropdown-menu" +import { Input } from "../ui/input" +import { Tabs, TabsList, TabsTrigger } from "../ui/tabs" +import { ModelIcon } from "./model-icon" +import { ModelOption } from "./model-option" + +import { useTranslation } from 'react-i18next' + +interface ModelSelectProps { + selectedModelId: string + onSelectModel: (modelId: LLMID) => void +} + +export const ModelSelect: FC = ({ + selectedModelId, + onSelectModel +}) => { + + const { t } = useTranslation() + + const { + profile, + models, + availableHostedModels, + availableLocalModels, + availableOpenRouterModels + } = useContext(ChatbotUIContext) + + const inputRef = useRef(null) + const triggerRef = useRef(null) + + const [isOpen, setIsOpen] = useState(false) + const [search, setSearch] = useState("") + const [tab, setTab] = useState<"hosted" | "local">("hosted") + + useEffect(() => { + if (isOpen) { + setTimeout(() => { + inputRef.current?.focus() + }, 100) // FIX: hacky + } + }, [isOpen]) + + const handleSelectModel = (modelId: LLMID) => { + onSelectModel(modelId) + setIsOpen(false) + } + + const allModels = [ + ...models.map(model => ({ + modelId: model.model_id as LLMID, + modelName: model.name, + provider: "custom" as ModelProvider, + hostedId: model.id, + platformLink: "", + imageInput: false + })), + ...availableHostedModels, + ...availableLocalModels, + ...availableOpenRouterModels + ] + + const groupedModels = allModels.reduce>( + (groups, model) => { + const key = model.provider + if (!groups[key]) { + groups[key] = [] + } + groups[key].push(model) + return groups + }, + {} + ) + + const selectedModel = allModels.find( + model => model.modelId === selectedModelId + ) + + if (!profile) return null + + return ( + { + setIsOpen(isOpen) + setSearch("") + }} + > + + {allModels.length === 0 ? ( +
+ {t("chat.unlockModelsMessage")} +
+ ) : ( + + )} +
+ + + setTab(value)}> + {availableLocalModels.length > 0 && ( + + {t("chat.hosted")} + + {t("chat.local")} + + )} + + + setSearch(e.target.value)} + /> + +
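Editor's note: the section below iterates groupedModels, which is produced earlier by a reduce keyed on provider and then narrowed by the hosted/local tab and the search box. A condensed sketch of both steps, with LLMLike standing in for the app's LLM type:

// Trimmed-down stand-in for the app's LLM type.
interface LLMLike {
  modelId: string
  modelName: string
  provider: string
}

type Tab = "hosted" | "local"

// Bucket models by provider so the menu can render one labelled section per provider.
const groupByProvider = (models: LLMLike[]): Record<string, LLMLike[]> =>
  models.reduce<Record<string, LLMLike[]>>((groups, model) => {
    if (!groups[model.provider]) {
      groups[model.provider] = []
    }
    groups[model.provider].push(model)
    return groups
  }, {})

// "local" means Ollama models; everything else is treated as hosted.
const filterForTab = (models: LLMLike[], tab: Tab, search: string): LLMLike[] =>
  models
    .filter(model =>
      tab === "local" ? model.provider === "ollama" : model.provider !== "ollama"
    )
    .filter(model =>
      model.modelName.toLowerCase().includes(search.toLowerCase())
    )

const grouped = groupByProvider([
  { modelId: "llama3", modelName: "Llama 3", provider: "ollama" },
  { modelId: "gpt-4", modelName: "GPT-4", provider: "openai" }
])
console.log(Object.keys(grouped)) // ["ollama", "openai"]
console.log(filterForTab(grouped["openai"] ?? [], "hosted", "gpt")) // the GPT-4 entry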
+ {Object.entries(groupedModels).map(([provider, models]) => { + const filteredModels = models + .filter(model => { + if (tab === "hosted") return model.provider !== "ollama" + if (tab === "local") return model.provider === "ollama" + }) + .filter(model => + model.modelName.toLowerCase().includes(search.toLowerCase()) + ) + .sort((a, b) => a.provider.localeCompare(b.provider)) + + if (filteredModels.length === 0) return null + + return ( +
+
+ {provider === "openai" && profile.use_azure_openai + ? "AZURE OPENAI" + : provider === "custom" + ? t("modelProvider.custom") + : provider.toUpperCase()} +
+ +
+ {filteredModels.map(model => { + return ( +
+ {selectedModelId === model.modelId && ( + + )} + + handleSelectModel(model.modelId)} + /> +
+ ) + })} +
+
+ ) + })} +
+
+
+ ) +} diff --git a/chatdesk-ui/components/setup/api-step.tsx b/chatdesk-ui/components/setup/api-step.tsx new file mode 100644 index 0000000..2800194 --- /dev/null +++ b/chatdesk-ui/components/setup/api-step.tsx @@ -0,0 +1,246 @@ +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { FC } from "react" +import { Button } from "../ui/button" +import { useTranslation } from 'react-i18next' + +interface APIStepProps { + openaiAPIKey: string + openaiOrgID: string + azureOpenaiAPIKey: string + azureOpenaiEndpoint: string + azureOpenai35TurboID: string + azureOpenai45TurboID: string + azureOpenai45VisionID: string + azureOpenaiEmbeddingsID: string + anthropicAPIKey: string + googleGeminiAPIKey: string + mistralAPIKey: string + groqAPIKey: string + perplexityAPIKey: string + useAzureOpenai: boolean + openrouterAPIKey: string + onOpenrouterAPIKeyChange: (value: string) => void + onOpenaiAPIKeyChange: (value: string) => void + onOpenaiOrgIDChange: (value: string) => void + onAzureOpenaiAPIKeyChange: (value: string) => void + onAzureOpenaiEndpointChange: (value: string) => void + onAzureOpenai35TurboIDChange: (value: string) => void + onAzureOpenai45TurboIDChange: (value: string) => void + onAzureOpenai45VisionIDChange: (value: string) => void + onAzureOpenaiEmbeddingsIDChange: (value: string) => void + onAnthropicAPIKeyChange: (value: string) => void + onGoogleGeminiAPIKeyChange: (value: string) => void + onMistralAPIKeyChange: (value: string) => void + onGroqAPIKeyChange: (value: string) => void + onPerplexityAPIKeyChange: (value: string) => void + onUseAzureOpenaiChange: (value: boolean) => void +} + +export const APIStep: FC = ({ + openaiAPIKey, + openaiOrgID, + azureOpenaiAPIKey, + azureOpenaiEndpoint, + azureOpenai35TurboID, + azureOpenai45TurboID, + azureOpenai45VisionID, + azureOpenaiEmbeddingsID, + anthropicAPIKey, + googleGeminiAPIKey, + mistralAPIKey, + groqAPIKey, + perplexityAPIKey, + openrouterAPIKey, + useAzureOpenai, + onOpenaiAPIKeyChange, + onOpenaiOrgIDChange, + onAzureOpenaiAPIKeyChange, + onAzureOpenaiEndpointChange, + onAzureOpenai35TurboIDChange, + onAzureOpenai45TurboIDChange, + onAzureOpenai45VisionIDChange, + onAzureOpenaiEmbeddingsIDChange, + onAnthropicAPIKeyChange, + onGoogleGeminiAPIKeyChange, + onMistralAPIKeyChange, + onGroqAPIKeyChange, + onPerplexityAPIKeyChange, + onUseAzureOpenaiChange, + onOpenrouterAPIKeyChange +}) => { + const { t } = useTranslation() + + return ( + <> +
+ + + + useAzureOpenai + ? onAzureOpenaiAPIKeyChange(e.target.value) + : onOpenaiAPIKeyChange(e.target.value) + } + /> +
+ +
+ {useAzureOpenai ? ( + <> +
+ + + onAzureOpenaiEndpointChange(e.target.value)} + /> +
+ +
+ + + onAzureOpenai35TurboIDChange(e.target.value)} + /> +
+ +
+ + + onAzureOpenai45TurboIDChange(e.target.value)} + /> +
+ +
+ + + onAzureOpenai45VisionIDChange(e.target.value)} + /> +
+ +
+ + + onAzureOpenaiEmbeddingsIDChange(e.target.value)} + /> +
+ + ) : ( + <> +
+ + + onOpenaiOrgIDChange(e.target.value)} + /> +
+ + )} +
+ +
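Editor's note: a single key field at the top of this step serves both providers, and the useAzureOpenai flag decides which change handler receives the value, mirroring the conditional section above. A minimal sketch of that routing with generic handler names (the real props are onAzureOpenaiAPIKeyChange and onOpenaiAPIKeyChange):

// Generic callbacks standing in for the step's Azure / OpenAI key props.
interface KeyHandlers {
  onAzureKeyChange: (value: string) => void
  onOpenaiKeyChange: (value: string) => void
}

// One text field serves both providers; the toggle decides which setting
// actually receives the typed value.
const routeApiKey = (
  useAzureOpenai: boolean,
  value: string,
  handlers: KeyHandlers
): void => {
  if (useAzureOpenai) {
    handlers.onAzureKeyChange(value)
  } else {
    handlers.onOpenaiKeyChange(value)
  }
}

routeApiKey(true, "sk-example", {
  onAzureKeyChange: v => console.log("azure key:", v),
  onOpenaiKeyChange: v => console.log("openai key:", v)
})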
+ + + onAnthropicAPIKeyChange(e.target.value)} + /> +
+ +
+ + + onGoogleGeminiAPIKeyChange(e.target.value)} + /> +
+ +
+ + + onMistralAPIKeyChange(e.target.value)} + /> +
+ +
+ + + onGroqAPIKeyChange(e.target.value)} + /> +
+ +
+ + + onPerplexityAPIKeyChange(e.target.value)} + /> +
+
+ + + onOpenrouterAPIKeyChange(e.target.value)} + /> +
+ + ) +} diff --git a/chatdesk-ui/components/setup/finish-step.tsx b/chatdesk-ui/components/setup/finish-step.tsx new file mode 100644 index 0000000..52160d3 --- /dev/null +++ b/chatdesk-ui/components/setup/finish-step.tsx @@ -0,0 +1,20 @@ +import { FC } from "react" +import { useTranslation } from 'react-i18next' + +interface FinishStepProps { + displayName: string +} + +export const FinishStep: FC = ({ displayName }) => { + const { t } = useTranslation() + return ( +
+
+ {t("setup.WelcomeToChatAIUI")} + {displayName.length > 0 ? `, ${displayName.split(" ")[0]}` : null}! +
+ +
{t("setup.ClickNextToStartChatting")}
+
+ ) +} diff --git a/chatdesk-ui/components/setup/profile-step.tsx b/chatdesk-ui/components/setup/profile-step.tsx new file mode 100644 index 0000000..99f273e --- /dev/null +++ b/chatdesk-ui/components/setup/profile-step.tsx @@ -0,0 +1,154 @@ +import { useTranslation } from "react-i18next" // 导入 useTranslation + +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { + PROFILE_DISPLAY_NAME_MAX, + PROFILE_USERNAME_MAX, + PROFILE_USERNAME_MIN +} from "@/db/limits" +import { + IconCircleCheckFilled, + IconCircleXFilled, + IconLoader2 +} from "@tabler/icons-react" +import { FC, useCallback, useState } from "react" +import { LimitDisplay } from "../ui/limit-display" +import { toast } from "sonner" + +interface ProfileStepProps { + username: string + usernameAvailable: boolean + displayName: string + onUsernameAvailableChange: (isAvailable: boolean) => void + onUsernameChange: (username: string) => void + onDisplayNameChange: (name: string) => void +} + +export const ProfileStep: FC = ({ + username, + usernameAvailable, + displayName, + onUsernameAvailableChange, + onUsernameChange, + onDisplayNameChange +}) => { + const { t } = useTranslation() // 使用 t 函数来获取翻译内容 + + const [loading, setLoading] = useState(false) + + const debounce = (func: (...args: any[]) => void, wait: number) => { + let timeout: NodeJS.Timeout | null + + return (...args: any[]) => { + const later = () => { + if (timeout) clearTimeout(timeout) + func(...args) + } + + if (timeout) clearTimeout(timeout) + timeout = setTimeout(later, wait) + } + } + + const checkUsernameAvailability = useCallback( + debounce(async (username: string) => { + if (!username) return + + if (username.length < PROFILE_USERNAME_MIN) { + onUsernameAvailableChange(false) + return + } + + if (username.length > PROFILE_USERNAME_MAX) { + onUsernameAvailableChange(false) + return + } + + const usernameRegex = /^[a-zA-Z0-9_]+$/ + if (!usernameRegex.test(username)) { + onUsernameAvailableChange(false) + toast.error( + t("login.usernameError") + //"Username must be letters, numbers, or underscores only - no other characters or spacing allowed." + ) + return + } + + setLoading(true) + + const response = await fetch(`/api/username/available`, { + method: "POST", + body: JSON.stringify({ username }) + }) + + const data = await response.json() + const isAvailable = data.isAvailable + + onUsernameAvailableChange(isAvailable) + + setLoading(false) + }, 500), + [] + ) + + return ( + <> +
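Editor's note: ProfileStep wraps its availability request in a hand-rolled debounce so a request is not fired on every keystroke. Below is a standalone sketch of that timing wrapper with a usage example; the validation rules and the /api/username/available endpoint above are left untouched here.

// Debounce: only run `fn` once the caller has been quiet for `wait` ms.
// Each new call cancels the previously scheduled one.
function debounce<Args extends unknown[]>(
  fn: (...args: Args) => void,
  wait: number
) {
  let timeout: ReturnType<typeof setTimeout> | null = null

  return (...args: Args) => {
    if (timeout) clearTimeout(timeout)
    timeout = setTimeout(() => {
      timeout = null
      fn(...args)
    }, wait)
  }
}

// Usage: the availability check only fires 500 ms after the last keystroke.
const checkUsername = debounce((username: string) => {
  console.log("checking availability for", username)
}, 500)

checkUsername("al")
checkUsername("ali")
checkUsername("alice") // only this call ends up logging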
+
+ + +
+ {usernameAvailable ? ( +
{t("login.available")}
+ ) : ( +
{t("login.unavailable")}
+ )} +
+
+ +
+ { + onUsernameChange(e.target.value) + checkUsernameAvailability(e.target.value) + }} + minLength={PROFILE_USERNAME_MIN} + maxLength={PROFILE_USERNAME_MAX} + /> + +
+ {loading ? ( + + ) : usernameAvailable ? ( + + ) : ( + + )} +
+
+ + +
+ +
+ + + onDisplayNameChange(e.target.value)} + maxLength={PROFILE_DISPLAY_NAME_MAX} + /> + + +
+ + ) +} diff --git a/chatdesk-ui/components/setup/step-container.tsx b/chatdesk-ui/components/setup/step-container.tsx new file mode 100644 index 0000000..f2ad0e2 --- /dev/null +++ b/chatdesk-ui/components/setup/step-container.tsx @@ -0,0 +1,91 @@ +import { Button } from "@/components/ui/button" +import { + Card, + CardContent, + CardDescription, + CardFooter, + CardHeader, + CardTitle +} from "@/components/ui/card" +import { FC, useRef } from "react" + +import { useTranslation } from 'react-i18next' + +export const SETUP_STEP_COUNT = 3 + +interface StepContainerProps { + stepDescription: string + stepNum: number + stepTitle: string + onShouldProceed: (shouldProceed: boolean) => void + children?: React.ReactNode + showBackButton?: boolean + showNextButton?: boolean +} + +export const StepContainer: FC = ({ + stepDescription, + stepNum, + stepTitle, + onShouldProceed, + children, + showBackButton = false, + showNextButton = true +}) => { + const buttonRef = useRef(null) + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === "Enter" && !e.shiftKey) { + if (buttonRef.current) { + buttonRef.current.click() + } + } + } + const { t } = useTranslation() + return ( + + + +
{stepTitle}
+ +
+ {stepNum} / {SETUP_STEP_COUNT} +
+
+ + {stepDescription} +
+ + {children} + + +
+ {showBackButton && ( + + )} +
+ +
+ {showNextButton && ( + + )} +
+
+
+ ) +} diff --git a/chatdesk-ui/components/sidebar/items/all/sidebar-create-item.tsx b/chatdesk-ui/components/sidebar/items/all/sidebar-create-item.tsx new file mode 100644 index 0000000..e089826 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/all/sidebar-create-item.tsx @@ -0,0 +1,278 @@ +import { Button } from "@/components/ui/button" +import { + Sheet, + SheetContent, + SheetFooter, + SheetHeader, + SheetTitle +} from "@/components/ui/sheet" +import { ChatbotUIContext } from "@/context/context" +import { createAssistantCollections } from "@/db/assistant-collections" +import { createAssistantFiles } from "@/db/assistant-files" +import { createAssistantTools } from "@/db/assistant-tools" +import { createAssistant, updateAssistant } from "@/db/assistants" +import { createChat } from "@/db/chats" +import { createCollectionFiles } from "@/db/collection-files" +import { createCollection } from "@/db/collections" +import { createFileBasedOnExtension } from "@/db/files" +import { createModel } from "@/db/models" +import { createPreset } from "@/db/presets" +import { createPrompt } from "@/db/prompts" +import { + getAssistantImageFromStorage, + uploadAssistantImage +} from "@/db/storage/assistant-images" +import { createTool } from "@/db/tools" +import { convertBlobToBase64 } from "@/lib/blob-to-b64" +import { Tables, TablesInsert } from "@/supabase/types" +import { ContentType } from "@/types" +import { FC, useContext, useRef, useState } from "react" +import { toast } from "sonner" + +import { useTranslation } from 'react-i18next' + +interface SidebarCreateItemProps { + isOpen: boolean + isTyping: boolean + onOpenChange: (isOpen: boolean) => void + contentType: ContentType + renderInputs: () => JSX.Element + createState: any +} + +export const SidebarCreateItem: FC = ({ + isOpen, + onOpenChange, + contentType, + renderInputs, + createState, + isTyping +}) => { + + const { t, i18n } = useTranslation() + + const { + selectedWorkspace, + setChats, + setPresets, + setPrompts, + setFiles, + setCollections, + setAssistants, + setAssistantImages, + setTools, + setModels + } = useContext(ChatbotUIContext) + + const buttonRef = useRef(null) + + const [creating, setCreating] = useState(false) + + const createFunctions = { + chats: createChat, + presets: createPreset, + prompts: createPrompt, + files: async ( + createState: { file: File } & TablesInsert<"files">, + workspaceId: string + ) => { + if (!selectedWorkspace) return + + const { file, ...rest } = createState + + const createdFile = await createFileBasedOnExtension( + file, + rest, + workspaceId, + selectedWorkspace.embeddings_provider as "openai" | "local" + ) + + return createdFile + }, + collections: async ( + createState: { + image: File + collectionFiles: TablesInsert<"collection_files">[] + } & Tables<"collections">, + workspaceId: string + ) => { + const { collectionFiles, ...rest } = createState + + const createdCollection = await createCollection(rest, workspaceId) + + const finalCollectionFiles = collectionFiles.map(collectionFile => ({ + ...collectionFile, + collection_id: createdCollection.id + })) + + await createCollectionFiles(finalCollectionFiles) + + return createdCollection + }, + assistants: async ( + createState: { + image: File + files: Tables<"files">[] + collections: Tables<"collections">[] + tools: Tables<"tools">[] + } & Tables<"assistants">, + workspaceId: string + ) => { + const { image, files, collections, tools, ...rest } = createState + + const createdAssistant = await createAssistant(rest, workspaceId) + + 
let updatedAssistant = createdAssistant + + if (image) { + const filePath = await uploadAssistantImage(createdAssistant, image) + + updatedAssistant = await updateAssistant(createdAssistant.id, { + image_path: filePath + }) + + const url = (await getAssistantImageFromStorage(filePath)) || "" + + if (url) { + const response = await fetch(url) + const blob = await response.blob() + const base64 = await convertBlobToBase64(blob) + + setAssistantImages(prev => [ + ...prev, + { + assistantId: updatedAssistant.id, + path: filePath, + base64, + url + } + ]) + } + } + + const assistantFiles = files.map(file => ({ + user_id: rest.user_id, + assistant_id: createdAssistant.id, + file_id: file.id + })) + + const assistantCollections = collections.map(collection => ({ + user_id: rest.user_id, + assistant_id: createdAssistant.id, + collection_id: collection.id + })) + + const assistantTools = tools.map(tool => ({ + user_id: rest.user_id, + assistant_id: createdAssistant.id, + tool_id: tool.id + })) + + await createAssistantFiles(assistantFiles) + await createAssistantCollections(assistantCollections) + await createAssistantTools(assistantTools) + + return updatedAssistant + }, + tools: createTool, + models: createModel + } + + const stateUpdateFunctions = { + chats: setChats, + presets: setPresets, + prompts: setPrompts, + files: setFiles, + collections: setCollections, + assistants: setAssistants, + tools: setTools, + models: setModels + } + + const handleCreate = async () => { + try { + if (!selectedWorkspace) return + if (isTyping) return // Prevent creation while typing + + const createFunction = createFunctions[contentType] + const setStateFunction = stateUpdateFunctions[contentType] + + if (!createFunction || !setStateFunction) return + + setCreating(true) + + const newItem = await createFunction(createState, selectedWorkspace.id) + + setStateFunction((prevItems: any) => [...prevItems, newItem]) + + onOpenChange(false) + setCreating(false) + } catch (error) { + toast.error(`Error creating ${contentType.slice(0, -1)}. ${error}.`) + setCreating(false) + } + } + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (!isTyping && e.key === "Enter" && !e.shiftKey) { + e.preventDefault() + buttonRef.current?.click() + } + } + +// 判断是否需要首字母大写(且做 -1 截断) +const needsUpperCaseFirstLetter = (language: string) => { + const languagesRequiringUpperCase = ['en', 'de', 'fr', 'es', 'it']; + return languagesRequiringUpperCase.includes(language); +}; + +// 处理翻译后的 contentType 文本 +const getCapitalizedContentType = (translated: string, language: string) => { + if (needsUpperCaseFirstLetter(language)) { + return translated.charAt(0).toUpperCase() + translated.slice(1, -1); // ✅ 按你的要求保留 .slice(1, -1) + } + return translated; +}; + + return ( + + +
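Editor's note: the getCapitalizedContentType helper above capitalizes the translated content-type label and drops the trailing plural character, but only for the languages listed in needsUpperCaseFirstLetter (the Chinese comments say as much: decide whether the language needs a capitalized first letter, then trim one character). A standalone sketch of the same rule:

// Languages whose labels get a capital first letter and lose the trailing
// plural "s" (e.g. "prompts" -> "Prompt"); CJK labels are left alone.
const LANGUAGES_WITH_CAPITALIZED_SINGULAR = ["en", "de", "fr", "es", "it"]

const capitalizeAndSingularize = (label: string, language: string): string => {
  if (!LANGUAGES_WITH_CAPITALIZED_SINGULAR.includes(language)) {
    return label
  }
  return label.charAt(0).toUpperCase() + label.slice(1, -1)
}

console.log(capitalizeAndSingularize("prompts", "en")) // "Prompt"
console.log(capitalizeAndSingularize("提示词", "zh"))   // unchanged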
+ + + {/* Create{" "} + {contentType.charAt(0).toUpperCase() + contentType.slice(1, -1)} */} + + {t("side.sidebarCreateNew")}{" "} + {getCapitalizedContentType(t(`contentType.${contentType}`), i18n.language)} + + + + + +
{renderInputs()}
+
+ + +
+ + + +
+
+
+
+ ) +} diff --git a/chatdesk-ui/components/sidebar/items/all/sidebar-delete-item.tsx b/chatdesk-ui/components/sidebar/items/all/sidebar-delete-item.tsx new file mode 100644 index 0000000..9e9344a --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/all/sidebar-delete-item.tsx @@ -0,0 +1,147 @@ +import { Button } from "@/components/ui/button" +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger +} from "@/components/ui/dialog" +import { ChatbotUIContext } from "@/context/context" +import { deleteAssistant } from "@/db/assistants" +import { deleteChat } from "@/db/chats" +import { deleteCollection } from "@/db/collections" +import { deleteFile } from "@/db/files" +import { deleteModel } from "@/db/models" +import { deletePreset } from "@/db/presets" +import { deletePrompt } from "@/db/prompts" +import { deleteFileFromStorage } from "@/db/storage/files" +import { deleteTool } from "@/db/tools" +import { Tables } from "@/supabase/types" +import { ContentType, DataItemType } from "@/types" +import { FC, useContext, useRef, useState } from "react" + +import { useTranslation } from 'react-i18next' + +interface SidebarDeleteItemProps { + item: DataItemType + contentType: ContentType +} + +export const SidebarDeleteItem: FC = ({ + item, + contentType +}) => { + + const { t } = useTranslation() + + const { + setChats, + setPresets, + setPrompts, + setFiles, + setCollections, + setAssistants, + setTools, + setModels + } = useContext(ChatbotUIContext) + + const buttonRef = useRef(null) + + const [showDialog, setShowDialog] = useState(false) + + const deleteFunctions = { + chats: async (chat: Tables<"chats">) => { + await deleteChat(chat.id) + }, + presets: async (preset: Tables<"presets">) => { + await deletePreset(preset.id) + }, + prompts: async (prompt: Tables<"prompts">) => { + await deletePrompt(prompt.id) + }, + files: async (file: Tables<"files">) => { + await deleteFileFromStorage(file.file_path) + await deleteFile(file.id) + }, + collections: async (collection: Tables<"collections">) => { + await deleteCollection(collection.id) + }, + assistants: async (assistant: Tables<"assistants">) => { + await deleteAssistant(assistant.id) + setChats(prevState => + prevState.filter(chat => chat.assistant_id !== assistant.id) + ) + }, + tools: async (tool: Tables<"tools">) => { + await deleteTool(tool.id) + }, + models: async (model: Tables<"models">) => { + await deleteModel(model.id) + } + } + + const stateUpdateFunctions = { + chats: setChats, + presets: setPresets, + prompts: setPrompts, + files: setFiles, + collections: setCollections, + assistants: setAssistants, + tools: setTools, + models: setModels + } + + const handleDelete = async () => { + const deleteFunction = deleteFunctions[contentType] + const setStateFunction = stateUpdateFunctions[contentType] + + if (!deleteFunction || !setStateFunction) return + + await deleteFunction(item as any) + + setStateFunction((prevItems: any) => + prevItems.filter((prevItem: any) => prevItem.id !== item.id) + ) + + setShowDialog(false) + } + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === "Enter") { + e.stopPropagation() + buttonRef.current?.click() + } + } + + return ( + + + + + + + + {t("side.delete")} {contentType.slice(0, -1)} + + + {t("side.confirmDelete")} {item.name}? 
+ + + + + + + + + + + ) +} diff --git a/chatdesk-ui/components/sidebar/items/all/sidebar-display-item.tsx b/chatdesk-ui/components/sidebar/items/all/sidebar-display-item.tsx new file mode 100644 index 0000000..792ddd7 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/all/sidebar-display-item.tsx @@ -0,0 +1,151 @@ +import { ChatbotUIContext } from "@/context/context" +import { createChat } from "@/db/chats" +import { cn } from "@/lib/utils" +import { Tables } from "@/supabase/types" +import { ContentType, DataItemType } from "@/types" +import { useRouter } from "next/navigation" +import { FC, useContext, useRef, useState } from "react" +import { SidebarUpdateItem } from "./sidebar-update-item" + +import { usePathname } from "next/navigation" +import i18nConfig from "@/i18nConfig" + +interface SidebarItemProps { + item: DataItemType + isTyping: boolean + contentType: ContentType + icon: React.ReactNode + updateState: any + renderInputs: (renderState: any) => JSX.Element +} + +export const SidebarItem: FC = ({ + item, + contentType, + updateState, + renderInputs, + icon, + isTyping +}) => { + const { selectedWorkspace, setChats, setSelectedAssistant } = + useContext(ChatbotUIContext) + + + + const pathname = usePathname() // 获取当前路径 + const pathSegments = pathname.split("/").filter(Boolean) + const locales = i18nConfig.locales + const defaultLocale = i18nConfig.defaultLocale + + let locale: (typeof locales)[number] = defaultLocale + + const segment = pathSegments[0] as (typeof locales)[number] + + if (locales.includes(segment)) { + locale = segment + } + const homePath = locale === defaultLocale ? "/" : `/${locale}` + + + + + + const router = useRouter() + + const itemRef = useRef(null) + + const [isHovering, setIsHovering] = useState(false) + + const actionMap = { + chats: async (item: any) => {}, + presets: async (item: any) => {}, + prompts: async (item: any) => {}, + files: async (item: any) => {}, + collections: async (item: any) => {}, + assistants: async (assistant: Tables<"assistants">) => { + if (!selectedWorkspace) return + + const createdChat = await createChat({ + user_id: assistant.user_id, + workspace_id: selectedWorkspace.id, + assistant_id: assistant.id, + context_length: assistant.context_length, + include_profile_context: assistant.include_profile_context, + include_workspace_instructions: + assistant.include_workspace_instructions, + model: assistant.model, + name: `Chat with ${assistant.name}`, + prompt: assistant.prompt, + temperature: assistant.temperature, + embeddings_provider: assistant.embeddings_provider + }) + + setChats(prevState => [createdChat, ...prevState]) + setSelectedAssistant(assistant) + + return router.push(`${homePath}/${selectedWorkspace.id}/chat/${createdChat.id}`) + // return router.push(`/${locale}/${selectedWorkspace.id}/chat/${createdChat.id}`) + }, + tools: async (item: any) => {}, + models: async (item: any) => {} + } + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === "Enter") { + e.stopPropagation() + itemRef.current?.click() + } + } + + // const handleClickAction = async ( + // e: React.MouseEvent + // ) => { + // e.stopPropagation() + + // const action = actionMap[contentType] + + // await action(item as any) + // } + + return ( + +
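Editor's note: SidebarItem above derives a locale-aware home path from the current pathname before routing into a freshly created chat. A minimal sketch of that derivation follows; the inline i18nConfig is illustrative, the app imports the real one from @/i18nConfig.

// Illustrative config; the app imports this from "@/i18nConfig".
const i18nConfig = {
  locales: ["en", "zh", "de"],
  defaultLocale: "en"
} as const

type Locale = (typeof i18nConfig.locales)[number]

// The first path segment is treated as a locale if it matches the config;
// the default locale is served from "/" without a prefix.
const getHomePath = (pathname: string): string => {
  const segment = pathname.split("/").filter(Boolean)[0] as Locale

  const locale: Locale = i18nConfig.locales.includes(segment)
    ? segment
    : i18nConfig.defaultLocale

  return locale === i18nConfig.defaultLocale ? "/" : `/${locale}`
}

console.log(getHomePath("/zh/workspace-1/chat/abc")) // "/zh"
console.log(getHomePath("/workspace-1/chat/abc"))    // "/"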
setIsHovering(true)} + onMouseLeave={() => setIsHovering(false)} + > + {icon} + +
+ {item.name} +
+ + {/* TODO */} + {/* {isHovering && ( + Start chat with {contentType.slice(0, -1)}
} + trigger={ + + } + /> + )} */} + +
+ ) +} diff --git a/chatdesk-ui/components/sidebar/items/all/sidebar-update-item.tsx b/chatdesk-ui/components/sidebar/items/all/sidebar-update-item.tsx new file mode 100644 index 0000000..08d339f --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/all/sidebar-update-item.tsx @@ -0,0 +1,684 @@ +import { Button } from "@/components/ui/button" +import { Label } from "@/components/ui/label" +import { + Sheet, + SheetContent, + SheetFooter, + SheetHeader, + SheetTitle, + SheetTrigger +} from "@/components/ui/sheet" +import { AssignWorkspaces } from "@/components/workspace/assign-workspaces" +import { ChatbotUIContext } from "@/context/context" +import { + createAssistantCollection, + deleteAssistantCollection, + getAssistantCollectionsByAssistantId +} from "@/db/assistant-collections" +import { + createAssistantFile, + deleteAssistantFile, + getAssistantFilesByAssistantId +} from "@/db/assistant-files" +import { + createAssistantTool, + deleteAssistantTool, + getAssistantToolsByAssistantId +} from "@/db/assistant-tools" +import { + createAssistantWorkspaces, + deleteAssistantWorkspace, + getAssistantWorkspacesByAssistantId, + updateAssistant +} from "@/db/assistants" +import { updateChat } from "@/db/chats" +import { + createCollectionFile, + deleteCollectionFile, + getCollectionFilesByCollectionId +} from "@/db/collection-files" +import { + createCollectionWorkspaces, + deleteCollectionWorkspace, + getCollectionWorkspacesByCollectionId, + updateCollection +} from "@/db/collections" +import { + createFileWorkspaces, + deleteFileWorkspace, + getFileWorkspacesByFileId, + updateFile +} from "@/db/files" +import { + createModelWorkspaces, + deleteModelWorkspace, + getModelWorkspacesByModelId, + updateModel +} from "@/db/models" +import { + createPresetWorkspaces, + deletePresetWorkspace, + getPresetWorkspacesByPresetId, + updatePreset +} from "@/db/presets" +import { + createPromptWorkspaces, + deletePromptWorkspace, + getPromptWorkspacesByPromptId, + updatePrompt +} from "@/db/prompts" +import { + getAssistantImageFromStorage, + uploadAssistantImage +} from "@/db/storage/assistant-images" +import { + createToolWorkspaces, + deleteToolWorkspace, + getToolWorkspacesByToolId, + updateTool +} from "@/db/tools" +import { convertBlobToBase64 } from "@/lib/blob-to-b64" +import { Tables, TablesUpdate } from "@/supabase/types" +import { CollectionFile, ContentType, DataItemType } from "@/types" +import { FC, useContext, useEffect, useRef, useState } from "react" +import profile from "react-syntax-highlighter/dist/esm/languages/hljs/profile" +import { toast } from "sonner" +import { SidebarDeleteItem } from "./sidebar-delete-item" + +import { useTranslation } from 'react-i18next' + +interface SidebarUpdateItemProps { + isTyping: boolean + item: DataItemType + contentType: ContentType + children: React.ReactNode + renderInputs: (renderState: any) => JSX.Element + updateState: any +} + +export const SidebarUpdateItem: FC = ({ + item, + contentType, + children, + renderInputs, + updateState, + isTyping +}) => { + const { t } = useTranslation() + + const { + workspaces, + selectedWorkspace, + setChats, + setPresets, + setPrompts, + setFiles, + setCollections, + setAssistants, + setTools, + setModels, + setAssistantImages + } = useContext(ChatbotUIContext) + + const buttonRef = useRef(null) + + const [isOpen, setIsOpen] = useState(false) + const [startingWorkspaces, setStartingWorkspaces] = useState< + Tables<"workspaces">[] + >([]) + const [selectedWorkspaces, setSelectedWorkspaces] = useState< + 
Tables<"workspaces">[] + >([]) + + // Collections Render State + const [startingCollectionFiles, setStartingCollectionFiles] = useState< + CollectionFile[] + >([]) + const [selectedCollectionFiles, setSelectedCollectionFiles] = useState< + CollectionFile[] + >([]) + + // Assistants Render State + const [startingAssistantFiles, setStartingAssistantFiles] = useState< + Tables<"files">[] + >([]) + const [startingAssistantCollections, setStartingAssistantCollections] = + useState[]>([]) + const [startingAssistantTools, setStartingAssistantTools] = useState< + Tables<"tools">[] + >([]) + const [selectedAssistantFiles, setSelectedAssistantFiles] = useState< + Tables<"files">[] + >([]) + const [selectedAssistantCollections, setSelectedAssistantCollections] = + useState[]>([]) + const [selectedAssistantTools, setSelectedAssistantTools] = useState< + Tables<"tools">[] + >([]) + + useEffect(() => { + if (isOpen) { + const fetchData = async () => { + if (workspaces.length > 1) { + const workspaces = await fetchSelectedWorkspaces() + setStartingWorkspaces(workspaces) + setSelectedWorkspaces(workspaces) + } + + const fetchDataFunction = fetchDataFunctions[contentType] + if (!fetchDataFunction) return + await fetchDataFunction(item.id) + } + + fetchData() + } + }, [isOpen]) + + const renderState = { + chats: null, + presets: null, + prompts: null, + files: null, + collections: { + startingCollectionFiles, + setStartingCollectionFiles, + selectedCollectionFiles, + setSelectedCollectionFiles + }, + assistants: { + startingAssistantFiles, + setStartingAssistantFiles, + startingAssistantCollections, + setStartingAssistantCollections, + startingAssistantTools, + setStartingAssistantTools, + selectedAssistantFiles, + setSelectedAssistantFiles, + selectedAssistantCollections, + setSelectedAssistantCollections, + selectedAssistantTools, + setSelectedAssistantTools + }, + tools: null, + models: null + } + + const fetchDataFunctions = { + chats: null, + presets: null, + prompts: null, + files: null, + collections: async (collectionId: string) => { + const collectionFiles = + await getCollectionFilesByCollectionId(collectionId) + setStartingCollectionFiles(collectionFiles.files) + setSelectedCollectionFiles([]) + }, + assistants: async (assistantId: string) => { + const assistantFiles = await getAssistantFilesByAssistantId(assistantId) + setStartingAssistantFiles(assistantFiles.files) + + const assistantCollections = + await getAssistantCollectionsByAssistantId(assistantId) + setStartingAssistantCollections(assistantCollections.collections) + + const assistantTools = await getAssistantToolsByAssistantId(assistantId) + setStartingAssistantTools(assistantTools.tools) + + setSelectedAssistantFiles([]) + setSelectedAssistantCollections([]) + setSelectedAssistantTools([]) + }, + tools: null, + models: null + } + + const fetchWorkpaceFunctions = { + chats: null, + presets: async (presetId: string) => { + const item = await getPresetWorkspacesByPresetId(presetId) + return item.workspaces + }, + prompts: async (promptId: string) => { + const item = await getPromptWorkspacesByPromptId(promptId) + return item.workspaces + }, + files: async (fileId: string) => { + const item = await getFileWorkspacesByFileId(fileId) + return item.workspaces + }, + collections: async (collectionId: string) => { + const item = await getCollectionWorkspacesByCollectionId(collectionId) + return item.workspaces + }, + assistants: async (assistantId: string) => { + const item = await getAssistantWorkspacesByAssistantId(assistantId) + return 
item.workspaces + }, + tools: async (toolId: string) => { + const item = await getToolWorkspacesByToolId(toolId) + return item.workspaces + }, + models: async (modelId: string) => { + const item = await getModelWorkspacesByModelId(modelId) + return item.workspaces + } + } + + const fetchSelectedWorkspaces = async () => { + const fetchFunction = fetchWorkpaceFunctions[contentType] + + if (!fetchFunction) return [] + + const workspaces = await fetchFunction(item.id) + + return workspaces + } + + const handleWorkspaceUpdates = async ( + startingWorkspaces: Tables<"workspaces">[], + selectedWorkspaces: Tables<"workspaces">[], + itemId: string, + deleteWorkspaceFn: ( + itemId: string, + workspaceId: string + ) => Promise, + createWorkspaceFn: ( + workspaces: { user_id: string; item_id: string; workspace_id: string }[] + ) => Promise, + itemIdKey: string + ) => { + if (!selectedWorkspace) return + + const deleteList = startingWorkspaces.filter( + startingWorkspace => + !selectedWorkspaces.some( + selectedWorkspace => selectedWorkspace.id === startingWorkspace.id + ) + ) + + for (const workspace of deleteList) { + await deleteWorkspaceFn(itemId, workspace.id) + } + + if (deleteList.map(w => w.id).includes(selectedWorkspace.id)) { + const setStateFunction = stateUpdateFunctions[contentType] + + if (setStateFunction) { + setStateFunction((prevItems: any) => + prevItems.filter((prevItem: any) => prevItem.id !== item.id) + ) + } + } + + const createList = selectedWorkspaces.filter( + selectedWorkspace => + !startingWorkspaces.some( + startingWorkspace => startingWorkspace.id === selectedWorkspace.id + ) + ) + + await createWorkspaceFn( + createList.map(workspace => { + return { + user_id: workspace.user_id, + [itemIdKey]: itemId, + workspace_id: workspace.id + } as any + }) + ) + } + + const updateFunctions = { + chats: updateChat, + presets: async (presetId: string, updateState: TablesUpdate<"presets">) => { + const updatedPreset = await updatePreset(presetId, updateState) + + await handleWorkspaceUpdates( + startingWorkspaces, + selectedWorkspaces, + presetId, + deletePresetWorkspace, + createPresetWorkspaces as any, + "preset_id" + ) + + return updatedPreset + }, + prompts: async (promptId: string, updateState: TablesUpdate<"prompts">) => { + const updatedPrompt = await updatePrompt(promptId, updateState) + + await handleWorkspaceUpdates( + startingWorkspaces, + selectedWorkspaces, + promptId, + deletePromptWorkspace, + createPromptWorkspaces as any, + "prompt_id" + ) + + return updatedPrompt + }, + files: async (fileId: string, updateState: TablesUpdate<"files">) => { + const updatedFile = await updateFile(fileId, updateState) + + await handleWorkspaceUpdates( + startingWorkspaces, + selectedWorkspaces, + fileId, + deleteFileWorkspace, + createFileWorkspaces as any, + "file_id" + ) + + return updatedFile + }, + collections: async ( + collectionId: string, + updateState: TablesUpdate<"assistants"> + ) => { + if (!profile) return + + const { ...rest } = updateState + + const filesToAdd = selectedCollectionFiles.filter( + selectedFile => + !startingCollectionFiles.some( + startingFile => startingFile.id === selectedFile.id + ) + ) + + const filesToRemove = startingCollectionFiles.filter(startingFile => + selectedCollectionFiles.some( + selectedFile => selectedFile.id === startingFile.id + ) + ) + + for (const file of filesToAdd) { + await createCollectionFile({ + user_id: item.user_id, + collection_id: collectionId, + file_id: file.id + }) + } + + for (const file of filesToRemove) { + await 
deleteCollectionFile(collectionId, file.id) + } + + const updatedCollection = await updateCollection(collectionId, rest) + + await handleWorkspaceUpdates( + startingWorkspaces, + selectedWorkspaces, + collectionId, + deleteCollectionWorkspace, + createCollectionWorkspaces as any, + "collection_id" + ) + + return updatedCollection + }, + assistants: async ( + assistantId: string, + updateState: { + assistantId: string + image: File + } & TablesUpdate<"assistants"> + ) => { + const { image, ...rest } = updateState + + const filesToAdd = selectedAssistantFiles.filter( + selectedFile => + !startingAssistantFiles.some( + startingFile => startingFile.id === selectedFile.id + ) + ) + + const filesToRemove = startingAssistantFiles.filter(startingFile => + selectedAssistantFiles.some( + selectedFile => selectedFile.id === startingFile.id + ) + ) + + for (const file of filesToAdd) { + await createAssistantFile({ + user_id: item.user_id, + assistant_id: assistantId, + file_id: file.id + }) + } + + for (const file of filesToRemove) { + await deleteAssistantFile(assistantId, file.id) + } + + const collectionsToAdd = selectedAssistantCollections.filter( + selectedCollection => + !startingAssistantCollections.some( + startingCollection => + startingCollection.id === selectedCollection.id + ) + ) + + const collectionsToRemove = startingAssistantCollections.filter( + startingCollection => + selectedAssistantCollections.some( + selectedCollection => + selectedCollection.id === startingCollection.id + ) + ) + + for (const collection of collectionsToAdd) { + await createAssistantCollection({ + user_id: item.user_id, + assistant_id: assistantId, + collection_id: collection.id + }) + } + + for (const collection of collectionsToRemove) { + await deleteAssistantCollection(assistantId, collection.id) + } + + const toolsToAdd = selectedAssistantTools.filter( + selectedTool => + !startingAssistantTools.some( + startingTool => startingTool.id === selectedTool.id + ) + ) + + const toolsToRemove = startingAssistantTools.filter(startingTool => + selectedAssistantTools.some( + selectedTool => selectedTool.id === startingTool.id + ) + ) + + for (const tool of toolsToAdd) { + await createAssistantTool({ + user_id: item.user_id, + assistant_id: assistantId, + tool_id: tool.id + }) + } + + for (const tool of toolsToRemove) { + await deleteAssistantTool(assistantId, tool.id) + } + + let updatedAssistant = await updateAssistant(assistantId, rest) + + if (image) { + const filePath = await uploadAssistantImage(updatedAssistant, image) + + updatedAssistant = await updateAssistant(assistantId, { + image_path: filePath + }) + + const url = (await getAssistantImageFromStorage(filePath)) || "" + + if (url) { + const response = await fetch(url) + const blob = await response.blob() + const base64 = await convertBlobToBase64(blob) + + setAssistantImages(prev => [ + ...prev, + { + assistantId: updatedAssistant.id, + path: filePath, + base64, + url + } + ]) + } + } + + await handleWorkspaceUpdates( + startingWorkspaces, + selectedWorkspaces, + assistantId, + deleteAssistantWorkspace, + createAssistantWorkspaces as any, + "assistant_id" + ) + + return updatedAssistant + }, + tools: async (toolId: string, updateState: TablesUpdate<"tools">) => { + const updatedTool = await updateTool(toolId, updateState) + + await handleWorkspaceUpdates( + startingWorkspaces, + selectedWorkspaces, + toolId, + deleteToolWorkspace, + createToolWorkspaces as any, + "tool_id" + ) + + return updatedTool + }, + models: async (modelId: string, updateState: 
TablesUpdate<"models">) => { + const updatedModel = await updateModel(modelId, updateState) + + await handleWorkspaceUpdates( + startingWorkspaces, + selectedWorkspaces, + modelId, + deleteModelWorkspace, + createModelWorkspaces as any, + "model_id" + ) + + return updatedModel + } + } + + const stateUpdateFunctions = { + chats: setChats, + presets: setPresets, + prompts: setPrompts, + files: setFiles, + collections: setCollections, + assistants: setAssistants, + tools: setTools, + models: setModels + } + + const handleUpdate = async () => { + try { + const updateFunction = updateFunctions[contentType] + const setStateFunction = stateUpdateFunctions[contentType] + + if (!updateFunction || !setStateFunction) return + if (isTyping) return // Prevent update while typing + + const updatedItem = await updateFunction(item.id, updateState) + + setStateFunction((prevItems: any) => + prevItems.map((prevItem: any) => + prevItem.id === item.id ? updatedItem : prevItem + ) + ) + + setIsOpen(false) + + toast.success(`${contentType.slice(0, -1)} updated successfully`) + } catch (error) { + toast.error(`Error updating ${contentType.slice(0, -1)}. ${error}`) + } + } + + const handleSelectWorkspace = (workspace: Tables<"workspaces">) => { + setSelectedWorkspaces(prevState => { + const isWorkspaceAlreadySelected = prevState.find( + selectedWorkspace => selectedWorkspace.id === workspace.id + ) + + if (isWorkspaceAlreadySelected) { + return prevState.filter( + selectedWorkspace => selectedWorkspace.id !== workspace.id + ) + } else { + return [...prevState, workspace] + } + }) + } + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (!isTyping && e.key === "Enter" && !e.shiftKey) { + e.preventDefault() + buttonRef.current?.click() + } + } + + return ( + + {children} + + +
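Editor's note: handleWorkspaceUpdates above reconciles the starting and selected workspace lists by computing two differences, rows to delete and rows to create. A generic sketch of that diff step, independent of the Supabase helpers:

interface HasId {
  id: string
}

interface ListDiff<T> {
  toDelete: T[]
  toCreate: T[]
}

// Items present at the start but no longer selected are deleted; items newly
// selected are created. Everything else is left untouched.
function diffById<T extends HasId>(starting: T[], selected: T[]): ListDiff<T> {
  const startingIds = new Set(starting.map(item => item.id))
  const selectedIds = new Set(selected.map(item => item.id))

  return {
    toDelete: starting.filter(item => !selectedIds.has(item.id)),
    toCreate: selected.filter(item => !startingIds.has(item.id))
  }
}

const { toDelete, toCreate } = diffById(
  [{ id: "w1" }, { id: "w2" }],
  [{ id: "w2" }, { id: "w3" }]
)
console.log(toDelete) // [{ id: "w1" }]
console.log(toCreate) // [{ id: "w3" }]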
+ + + {t("side.edit")} {contentType.slice(0, -1)} + + + +
+ {workspaces.length > 1 && ( +
+ + + +
+ )} + + {renderInputs(renderState[contentType])} +
+
+ + + + +
+ + + +
+
+
+
+ ) +} diff --git a/chatdesk-ui/components/sidebar/items/assistants/assistant-item.tsx b/chatdesk-ui/components/sidebar/items/assistants/assistant-item.tsx new file mode 100644 index 0000000..823ac57 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/assistants/assistant-item.tsx @@ -0,0 +1,307 @@ +import { ChatSettingsForm } from "@/components/ui/chat-settings-form" +import ImagePicker from "@/components/ui/image-picker" +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { ChatbotUIContext } from "@/context/context" +import { ASSISTANT_DESCRIPTION_MAX, ASSISTANT_NAME_MAX } from "@/db/limits" +import { Tables } from "@/supabase/types" +import { IconRobotFace } from "@tabler/icons-react" +import Image from "next/image" +import { FC, useContext, useEffect, useState } from "react" +import profile from "react-syntax-highlighter/dist/esm/languages/hljs/profile" +import { SidebarItem } from "../all/sidebar-display-item" +import { AssistantRetrievalSelect } from "./assistant-retrieval-select" +import { AssistantToolSelect } from "./assistant-tool-select" + +import { useTranslation } from 'react-i18next' + +interface AssistantItemProps { + assistant: Tables<"assistants"> +} + +export const AssistantItem: FC = ({ assistant }) => { + const { t } = useTranslation() + + const { selectedWorkspace, assistantImages } = useContext(ChatbotUIContext) + + const [name, setName] = useState(assistant.name) + const [isTyping, setIsTyping] = useState(false) + const [description, setDescription] = useState(assistant.description) + const [assistantChatSettings, setAssistantChatSettings] = useState({ + model: assistant.model, + prompt: assistant.prompt, + temperature: assistant.temperature, + contextLength: assistant.context_length, + includeProfileContext: assistant.include_profile_context, + includeWorkspaceInstructions: assistant.include_workspace_instructions + }) + const [selectedImage, setSelectedImage] = useState(null) + const [imageLink, setImageLink] = useState("") + + useEffect(() => { + const assistantImage = + assistantImages.find(image => image.path === assistant.image_path) + ?.base64 || "" + setImageLink(assistantImage) + }, [assistant, assistantImages]) + + const handleFileSelect = ( + file: Tables<"files">, + setSelectedAssistantFiles: React.Dispatch< + React.SetStateAction[]> + > + ) => { + setSelectedAssistantFiles(prevState => { + const isFileAlreadySelected = prevState.find( + selectedFile => selectedFile.id === file.id + ) + + if (isFileAlreadySelected) { + return prevState.filter(selectedFile => selectedFile.id !== file.id) + } else { + return [...prevState, file] + } + }) + } + + const handleCollectionSelect = ( + collection: Tables<"collections">, + setSelectedAssistantCollections: React.Dispatch< + React.SetStateAction[]> + > + ) => { + setSelectedAssistantCollections(prevState => { + const isCollectionAlreadySelected = prevState.find( + selectedCollection => selectedCollection.id === collection.id + ) + + if (isCollectionAlreadySelected) { + return prevState.filter( + selectedCollection => selectedCollection.id !== collection.id + ) + } else { + return [...prevState, collection] + } + }) + } + + const handleToolSelect = ( + tool: Tables<"tools">, + setSelectedAssistantTools: React.Dispatch< + React.SetStateAction[]> + > + ) => { + setSelectedAssistantTools(prevState => { + const isToolAlreadySelected = prevState.find( + selectedTool => selectedTool.id === tool.id + ) + + if (isToolAlreadySelected) { + return 
prevState.filter(selectedTool => selectedTool.id !== tool.id) + } else { + return [...prevState, tool] + } + }) + } + + if (!profile) return null + if (!selectedWorkspace) return null + + return ( + + ) : ( + + ) + } + updateState={{ + image: selectedImage, + user_id: assistant.user_id, + name, + description, + include_profile_context: assistantChatSettings.includeProfileContext, + include_workspace_instructions: + assistantChatSettings.includeWorkspaceInstructions, + context_length: assistantChatSettings.contextLength, + model: assistantChatSettings.model, + image_path: assistant.image_path, + prompt: assistantChatSettings.prompt, + temperature: assistantChatSettings.temperature + }} + renderInputs={(renderState: { + startingAssistantFiles: Tables<"files">[] + setStartingAssistantFiles: React.Dispatch< + React.SetStateAction[]> + > + selectedAssistantFiles: Tables<"files">[] + setSelectedAssistantFiles: React.Dispatch< + React.SetStateAction[]> + > + startingAssistantCollections: Tables<"collections">[] + setStartingAssistantCollections: React.Dispatch< + React.SetStateAction[]> + > + selectedAssistantCollections: Tables<"collections">[] + setSelectedAssistantCollections: React.Dispatch< + React.SetStateAction[]> + > + startingAssistantTools: Tables<"tools">[] + setStartingAssistantTools: React.Dispatch< + React.SetStateAction[]> + > + selectedAssistantTools: Tables<"tools">[] + setSelectedAssistantTools: React.Dispatch< + React.SetStateAction[]> + > + }) => ( + <> +
+ + + setName(e.target.value)} + maxLength={ASSISTANT_NAME_MAX} + /> +
+ +
+ + + setDescription(e.target.value)} + maxLength={ASSISTANT_DESCRIPTION_MAX} + /> +
+ +
+ + + +
+ + + +
+ + + + ![ + ...renderState.selectedAssistantFiles, + ...renderState.selectedAssistantCollections + ].some( + selectedFile => selectedFile.id === startingFile.id + ) + ), + ...renderState.selectedAssistantFiles.filter( + selectedFile => + !renderState.startingAssistantFiles.some( + startingFile => startingFile.id === selectedFile.id + ) + ), + ...renderState.startingAssistantCollections.filter( + startingCollection => + ![ + ...renderState.selectedAssistantFiles, + ...renderState.selectedAssistantCollections + ].some( + selectedCollection => + selectedCollection.id === startingCollection.id + ) + ), + ...renderState.selectedAssistantCollections.filter( + selectedCollection => + !renderState.startingAssistantCollections.some( + startingCollection => + startingCollection.id === selectedCollection.id + ) + ) + ] + } + onAssistantRetrievalItemsSelect={item => + "type" in item + ? handleFileSelect( + item, + renderState.setSelectedAssistantFiles + ) + : handleCollectionSelect( + item, + renderState.setSelectedAssistantCollections + ) + } + /> +
+ +
+ + + + !renderState.selectedAssistantTools.some( + selectedTool => selectedTool.id === startingTool.id + ) + ), + ...renderState.selectedAssistantTools.filter( + selectedTool => + !renderState.startingAssistantTools.some( + startingTool => startingTool.id === selectedTool.id + ) + ) + ] + } + onAssistantToolsSelect={tool => + handleToolSelect(tool, renderState.setSelectedAssistantTools) + } + /> +
+ + )} + /> + ) +} diff --git a/chatdesk-ui/components/sidebar/items/assistants/assistant-retrieval-select.tsx b/chatdesk-ui/components/sidebar/items/assistants/assistant-retrieval-select.tsx new file mode 100644 index 0000000..efcfea7 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/assistants/assistant-retrieval-select.tsx @@ -0,0 +1,201 @@ +import { Button } from "@/components/ui/button" +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuTrigger +} from "@/components/ui/dropdown-menu" +import { Input } from "@/components/ui/input" +import { ChatbotUIContext } from "@/context/context" +import { Tables } from "@/supabase/types" +import { + IconBooks, + IconChevronDown, + IconCircleCheckFilled +} from "@tabler/icons-react" +import { FileIcon } from "lucide-react" +import { FC, useContext, useEffect, useRef, useState } from "react" + +import { useTranslation } from 'react-i18next' + +interface AssistantRetrievalSelectProps { + selectedAssistantRetrievalItems: Tables<"files">[] | Tables<"collections">[] + onAssistantRetrievalItemsSelect: ( + item: Tables<"files"> | Tables<"collections"> + ) => void +} + +export const AssistantRetrievalSelect: FC = ({ + selectedAssistantRetrievalItems, + onAssistantRetrievalItemsSelect +}) => { + const { t } = useTranslation() + + const { files, collections } = useContext(ChatbotUIContext) + + const inputRef = useRef(null) + const triggerRef = useRef(null) + + const [isOpen, setIsOpen] = useState(false) + const [search, setSearch] = useState("") + + useEffect(() => { + if (isOpen) { + setTimeout(() => { + inputRef.current?.focus() + }, 100) // FIX: hacky + } + }, [isOpen]) + + const handleItemSelect = (item: Tables<"files"> | Tables<"collections">) => { + onAssistantRetrievalItemsSelect(item) + } + + if (!files || !collections) return null + + return ( + { + setIsOpen(isOpen) + setSearch("") + }} + > + + + + + + setSearch(e.target.value)} + onKeyDown={e => e.stopPropagation()} + /> + + {selectedAssistantRetrievalItems + .filter(item => + item.name.toLowerCase().includes(search.toLowerCase()) + ) + .map(item => ( + | Tables<"collections">} + selected={selectedAssistantRetrievalItems.some( + selectedAssistantRetrieval => + selectedAssistantRetrieval.id === item.id + )} + onSelect={handleItemSelect} + /> + ))} + + {files + .filter( + file => + !selectedAssistantRetrievalItems.some( + selectedAssistantRetrieval => + selectedAssistantRetrieval.id === file.id + ) && file.name.toLowerCase().includes(search.toLowerCase()) + ) + .map(file => ( + + selectedAssistantRetrieval.id === file.id + )} + onSelect={handleItemSelect} + /> + ))} + + {collections + .filter( + collection => + !selectedAssistantRetrievalItems.some( + selectedAssistantRetrieval => + selectedAssistantRetrieval.id === collection.id + ) && collection.name.toLowerCase().includes(search.toLowerCase()) + ) + .map(collection => ( + + selectedAssistantRetrieval.id === collection.id + )} + onSelect={handleItemSelect} + /> + ))} + + + ) +} + +interface AssistantRetrievalOptionItemProps { + contentType: "files" | "collections" + item: Tables<"files"> | Tables<"collections"> + selected: boolean + onSelect: (item: Tables<"files"> | Tables<"collections">) => void +} + +const AssistantRetrievalItemOption: FC = ({ + contentType, + item, + selected, + onSelect +}) => { + const handleSelect = () => { + onSelect(item) + } + + return ( +
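Editor's note: handleFileSelect, handleCollectionSelect and handleToolSelect in AssistantItem above all implement the same toggle, where picking an already-selected item removes it and picking a new one appends it. A small generic sketch of that pattern:

interface Selectable {
  id: string
}

// Toggle semantics used by the assistant file/collection/tool pickers.
function toggleById<T extends Selectable>(selected: T[], item: T): T[] {
  const alreadySelected = selected.some(existing => existing.id === item.id)
  return alreadySelected
    ? selected.filter(existing => existing.id !== item.id)
    : [...selected, item]
}

let tools = [{ id: "t1" }]
tools = toggleById(tools, { id: "t2" }) // adds t2
tools = toggleById(tools, { id: "t1" }) // removes t1
console.log(tools) // [{ id: "t2" }]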
+
+ {contentType === "files" ? ( +
+ ).type} size={24} /> +
+ ) : ( +
+ +
+ )} + +
{item.name}
+
+ + {selected && ( + + )} +
+ ) +} diff --git a/chatdesk-ui/components/sidebar/items/assistants/assistant-tool-select.tsx b/chatdesk-ui/components/sidebar/items/assistants/assistant-tool-select.tsx new file mode 100644 index 0000000..73af4d5 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/assistants/assistant-tool-select.tsx @@ -0,0 +1,165 @@ +import { Button } from "@/components/ui/button" +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuTrigger +} from "@/components/ui/dropdown-menu" +import { Input } from "@/components/ui/input" +import { ChatbotUIContext } from "@/context/context" +import { Tables } from "@/supabase/types" +import { + IconBolt, + IconChevronDown, + IconCircleCheckFilled +} from "@tabler/icons-react" +import { FC, useContext, useEffect, useRef, useState } from "react" + +import { useTranslation } from 'react-i18next' + +interface AssistantToolSelectProps { + selectedAssistantTools: Tables<"tools">[] + onAssistantToolsSelect: (tool: Tables<"tools">) => void +} + +export const AssistantToolSelect: FC = ({ + selectedAssistantTools, + onAssistantToolsSelect +}) => { + const { t } = useTranslation() + + const { tools } = useContext(ChatbotUIContext) + + const inputRef = useRef(null) + const triggerRef = useRef(null) + + const [isOpen, setIsOpen] = useState(false) + const [search, setSearch] = useState("") + + useEffect(() => { + if (isOpen) { + setTimeout(() => { + inputRef.current?.focus() + }, 100) // FIX: hacky + } + }, [isOpen]) + + const handleToolSelect = (tool: Tables<"tools">) => { + onAssistantToolsSelect(tool) + } + + if (!tools) return null + + return ( + { + setIsOpen(isOpen) + setSearch("") + }} + > + + + + + + setSearch(e.target.value)} + onKeyDown={e => e.stopPropagation()} + /> + + {selectedAssistantTools + .filter(item => + item.name.toLowerCase().includes(search.toLowerCase()) + ) + .map(tool => ( + + selectedAssistantRetrieval.id === tool.id + )} + onSelect={handleToolSelect} + /> + ))} + + {tools + .filter( + tool => + !selectedAssistantTools.some( + selectedAssistantRetrieval => + selectedAssistantRetrieval.id === tool.id + ) && tool.name.toLowerCase().includes(search.toLowerCase()) + ) + .map(tool => ( + + selectedAssistantRetrieval.id === tool.id + )} + onSelect={handleToolSelect} + /> + ))} + + + ) +} + +interface AssistantToolItemProps { + tool: Tables<"tools"> + selected: boolean + onSelect: (tool: Tables<"tools">) => void +} + +const AssistantToolItem: FC = ({ + tool, + selected, + onSelect +}) => { + const handleSelect = () => { + onSelect(tool) + } + + return ( +
+
+
+ +
+ +
{tool.name}
+
+ + {selected && ( + + )} +
+ ) +} diff --git a/chatdesk-ui/components/sidebar/items/assistants/create-assistant.tsx b/chatdesk-ui/components/sidebar/items/assistants/create-assistant.tsx new file mode 100644 index 0000000..6efe75e --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/assistants/create-assistant.tsx @@ -0,0 +1,216 @@ +import { SidebarCreateItem } from "@/components/sidebar/items/all/sidebar-create-item" +import { ChatSettingsForm } from "@/components/ui/chat-settings-form" +import ImagePicker from "@/components/ui/image-picker" +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { ChatbotUIContext } from "@/context/context" +import { ASSISTANT_DESCRIPTION_MAX, ASSISTANT_NAME_MAX } from "@/db/limits" +import { Tables, TablesInsert } from "@/supabase/types" +import { FC, useContext, useEffect, useState } from "react" +import { AssistantRetrievalSelect } from "./assistant-retrieval-select" +import { AssistantToolSelect } from "./assistant-tool-select" + +import { useTranslation } from 'react-i18next' + +interface CreateAssistantProps { + isOpen: boolean + onOpenChange: (isOpen: boolean) => void +} + +export const CreateAssistant: FC = ({ + isOpen, + onOpenChange +}) => { + + const { t } = useTranslation() + + const { profile, selectedWorkspace } = useContext(ChatbotUIContext) + + const [name, setName] = useState("") + const [isTyping, setIsTyping] = useState(false) + const [description, setDescription] = useState("") + const [assistantChatSettings, setAssistantChatSettings] = useState({ + model: selectedWorkspace?.default_model, + prompt: selectedWorkspace?.default_prompt, + temperature: selectedWorkspace?.default_temperature, + contextLength: selectedWorkspace?.default_context_length, + includeProfileContext: false, + includeWorkspaceInstructions: false, + embeddingsProvider: selectedWorkspace?.embeddings_provider + }) + const [selectedImage, setSelectedImage] = useState(null) + const [imageLink, setImageLink] = useState("") + const [selectedAssistantRetrievalItems, setSelectedAssistantRetrievalItems] = + useState[] | Tables<"collections">[]>([]) + const [selectedAssistantToolItems, setSelectedAssistantToolItems] = useState< + Tables<"tools">[] + >([]) + + useEffect(() => { + setAssistantChatSettings(prevSettings => { + const previousPrompt = prevSettings.prompt || "" + const previousPromptParts = previousPrompt.split(". ") + + previousPromptParts[0] = name ? `You are ${name}` : "" + + return { + ...prevSettings, + prompt: previousPromptParts.join(". 
") + } + }) + }, [name]) + + const handleRetrievalItemSelect = ( + item: Tables<"files"> | Tables<"collections"> + ) => { + setSelectedAssistantRetrievalItems(prevState => { + const isItemAlreadySelected = prevState.find( + selectedItem => selectedItem.id === item.id + ) + + if (isItemAlreadySelected) { + return prevState.filter(selectedItem => selectedItem.id !== item.id) + } else { + return [...prevState, item] + } + }) + } + + const handleToolSelect = (item: Tables<"tools">) => { + setSelectedAssistantToolItems(prevState => { + const isItemAlreadySelected = prevState.find( + selectedItem => selectedItem.id === item.id + ) + + if (isItemAlreadySelected) { + return prevState.filter(selectedItem => selectedItem.id !== item.id) + } else { + return [...prevState, item] + } + }) + } + + const checkIfModelIsToolCompatible = () => { + if (!assistantChatSettings.model) return false + + const compatibleModels = [ + "gpt-4-turbo-preview", + "gpt-4-vision-preview", + "gpt-3.5-turbo-1106", + "gpt-4" + ] + const isModelCompatible = compatibleModels.includes( + assistantChatSettings.model + ) + + return isModelCompatible + } + + if (!profile) return null + if (!selectedWorkspace) return null + + return ( + + item.hasOwnProperty("type") + ) as Tables<"files">[], + collections: selectedAssistantRetrievalItems.filter( + item => !item.hasOwnProperty("type") + ) as Tables<"collections">[], + tools: selectedAssistantToolItems + } as TablesInsert<"assistants"> + } + isOpen={isOpen} + isTyping={isTyping} + renderInputs={() => ( + <> +
+ + + setName(e.target.value)} + maxLength={ASSISTANT_NAME_MAX} + /> +
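{/* Sketch, not from the original commit: the useEffect above keeps the prompt's
    first ". "-separated sentence in sync with the assistant name. In isolation
    (helper name is illustrative):
      const syncPromptWithName = (prompt: string, name: string): string => {
        const parts = (prompt || "").split(". ")
        parts[0] = name ? `You are ${name}` : ""
        return parts.join(". ")
      }
      // syncPromptWithName("You are Bob. Answer briefly.", "Ada")
      //   -> "You are Ada. Answer briefly."
*/}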
+ +
+ + + setDescription(e.target.value)} + maxLength={ASSISTANT_DESCRIPTION_MAX} + /> +
+ +
+ + + +
+ + + +
+ + + +
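{/* Sketch, not from the original commit: handleRetrievalItemSelect and
    handleToolSelect above share one toggle-by-id pattern; in isolation
    (helper name is illustrative):
      const toggleById = <T extends { id: string }>(selected: T[], item: T): T[] =>
        selected.some(s => s.id === item.id)
          ? selected.filter(s => s.id !== item.id)
          : [...selected, item]
    Picking an already-selected item removes it; otherwise it is appended.
*/}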
+ + {checkIfModelIsToolCompatible() ? ( +
+ + + +
+ ) : ( +
+ {t("side.modelIncompatibleWithTools")} +
+ )} + + )} + onOpenChange={onOpenChange} + /> + ) +} diff --git a/chatdesk-ui/components/sidebar/items/chat/chat-item.tsx b/chatdesk-ui/components/sidebar/items/chat/chat-item.tsx new file mode 100644 index 0000000..d06db05 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/chat/chat-item.tsx @@ -0,0 +1,122 @@ +import { ModelIcon } from "@/components/models/model-icon" +import { WithTooltip } from "@/components/ui/with-tooltip" +import { ChatbotUIContext } from "@/context/context" +import { LLM_LIST } from "@/lib/models/llm/llm-list" +import { cn } from "@/lib/utils" +import { Tables } from "@/supabase/types" +import { LLM } from "@/types" +import { IconRobotFace } from "@tabler/icons-react" +import Image from "next/image" +import { useParams, useRouter } from "next/navigation" +import { FC, useContext, useRef } from "react" +import { DeleteChat } from "./delete-chat" +import { UpdateChat } from "./update-chat" +import { usePathname } from "next/navigation" +import i18nConfig from "@/i18nConfig" + +interface ChatItemProps { + chat: Tables<"chats"> +} + +export const ChatItem: FC = ({ chat }) => { + const { + selectedWorkspace, + selectedChat, + availableLocalModels, + assistantImages, + availableOpenRouterModels + } = useContext(ChatbotUIContext) + + const pathname = usePathname() // 获取当前路径 + + const pathSegments = pathname.split("/").filter(Boolean) + const locales = i18nConfig.locales + const defaultLocale = i18nConfig.defaultLocale + + const segment = pathSegments[0] as (typeof locales)[number] + const pathLocale = locales.includes(segment) ? segment : null + const localePrefix = pathLocale && pathLocale !== defaultLocale ? `/${pathLocale}` : "" + + const router = useRouter() + const params = useParams() + const isActive = params.chatid === chat.id || selectedChat?.id === chat.id + + const itemRef = useRef(null) + + const handleClick = () => { + if (!selectedWorkspace) return + return router.push(`${localePrefix}/${selectedWorkspace.id}/chat/${chat.id}`) + } + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === "Enter") { + e.stopPropagation() + itemRef.current?.click() + } + } + + const MODEL_DATA = [ + ...LLM_LIST, + ...availableLocalModels, + ...availableOpenRouterModels + ].find(llm => llm.modelId === chat.model) as LLM + + const assistantImage = assistantImages.find( + image => image.assistantId === chat.assistant_id + )?.base64 + + return ( +
+ {chat.assistant_id ? ( + assistantImage ? ( + Assistant image + ) : ( + + ) + ) : ( + {MODEL_DATA?.modelName}
} + trigger={ + + } + /> + )} + +
+ {chat.name} +
+ +
{ + e.stopPropagation() + e.preventDefault() + }} + className={`ml-2 flex space-x-2 ${!isActive && "w-11 opacity-0 group-hover:opacity-100"}`} + > + + + +
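{/* Sketch, not from the original commit: how this component turns the current
    pathname into a locale prefix before pushing
    `${localePrefix}/${selectedWorkspace.id}/chat/${chat.id}`. The locale values
    below are illustrative; the real ones come from @/i18nConfig:
      const getLocalePrefix = (pathname: string, locales: string[], defaultLocale: string) => {
        const first = pathname.split("/").filter(Boolean)[0]
        const pathLocale = locales.includes(first) ? first : null
        return pathLocale && pathLocale !== defaultLocale ? `/${pathLocale}` : ""
      }
      // getLocalePrefix("/zh/workspace-id/chat/chat-id", ["en", "zh"], "en") -> "/zh"
      // getLocalePrefix("/workspace-id/chat/chat-id",    ["en", "zh"], "en") -> ""
*/}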
+ + ) +} diff --git a/chatdesk-ui/components/sidebar/items/chat/delete-chat.tsx b/chatdesk-ui/components/sidebar/items/chat/delete-chat.tsx new file mode 100644 index 0000000..162cf73 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/chat/delete-chat.tsx @@ -0,0 +1,85 @@ +import { useChatHandler } from "@/components/chat/chat-hooks/use-chat-handler" +import { Button } from "@/components/ui/button" +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger +} from "@/components/ui/dialog" +import { ChatbotUIContext } from "@/context/context" +import { deleteChat } from "@/db/chats" +import useHotkey from "@/lib/hooks/use-hotkey" +import { Tables } from "@/supabase/types" +import { IconTrash } from "@tabler/icons-react" +import { FC, useContext, useRef, useState } from "react" + +import { useTranslation } from 'react-i18next' + +interface DeleteChatProps { + chat: Tables<"chats"> +} + +export const DeleteChat: FC = ({ chat }) => { + + const { t } = useTranslation() + + useHotkey("Backspace", () => setShowChatDialog(true)) + + const { setChats } = useContext(ChatbotUIContext) + const { handleNewChat } = useChatHandler() + + const buttonRef = useRef(null) + + const [showChatDialog, setShowChatDialog] = useState(false) + + const handleDeleteChat = async () => { + await deleteChat(chat.id) + + setChats(prevState => prevState.filter(c => c.id !== chat.id)) + + setShowChatDialog(false) + + handleNewChat() + } + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === "Enter") { + buttonRef.current?.click() + } + } + + return ( + + + + + + + + {t("side.deleteChatTitle")} {chat.name} + + + {t("side.deleteChatConfirm")} + + + + + + + + + + + ) +} diff --git a/chatdesk-ui/components/sidebar/items/chat/update-chat.tsx b/chatdesk-ui/components/sidebar/items/chat/update-chat.tsx new file mode 100644 index 0000000..0ef2687 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/chat/update-chat.tsx @@ -0,0 +1,81 @@ +import { Button } from "@/components/ui/button" +import { + Dialog, + DialogContent, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger +} from "@/components/ui/dialog" +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { ChatbotUIContext } from "@/context/context" +import { updateChat } from "@/db/chats" +import { Tables } from "@/supabase/types" +import { IconEdit } from "@tabler/icons-react" +import { FC, useContext, useRef, useState } from "react" + +import { useTranslation } from 'react-i18next' + +interface UpdateChatProps { + chat: Tables<"chats"> +} + +export const UpdateChat: FC = ({ chat }) => { + + const { t } = useTranslation() + + const { setChats } = useContext(ChatbotUIContext) + + const buttonRef = useRef(null) + + const [showChatDialog, setShowChatDialog] = useState(false) + const [name, setName] = useState(chat.name) + + const handleUpdateChat = async (e: React.MouseEvent) => { + const updatedChat = await updateChat(chat.id, { + name + }) + setChats(prevState => + prevState.map(c => (c.id === chat.id ? updatedChat : c)) + ) + + setShowChatDialog(false) + } + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === "Enter") { + buttonRef.current?.click() + } + } + + return ( + + + + + + + + {t("side.editChat")} + + +
+ + + setName(e.target.value)} /> +
+ + + + + + +
+
+ ) +} diff --git a/chatdesk-ui/components/sidebar/items/collections/collection-file-select.tsx b/chatdesk-ui/components/sidebar/items/collections/collection-file-select.tsx new file mode 100644 index 0000000..2c1f96f --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/collections/collection-file-select.tsx @@ -0,0 +1,159 @@ +import { Button } from "@/components/ui/button" +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuTrigger +} from "@/components/ui/dropdown-menu" +import { FileIcon } from "@/components/ui/file-icon" +import { Input } from "@/components/ui/input" +import { ChatbotUIContext } from "@/context/context" +import { CollectionFile } from "@/types" +import { IconChevronDown, IconCircleCheckFilled } from "@tabler/icons-react" +import { FC, useContext, useEffect, useRef, useState } from "react" + +import { useTranslation } from 'react-i18next' + +interface CollectionFileSelectProps { + selectedCollectionFiles: CollectionFile[] + onCollectionFileSelect: (file: CollectionFile) => void +} + +export const CollectionFileSelect: FC = ({ + selectedCollectionFiles, + onCollectionFileSelect +}) => { + const { t } = useTranslation() + + const { files } = useContext(ChatbotUIContext) + + const inputRef = useRef(null) + const triggerRef = useRef(null) + + const [isOpen, setIsOpen] = useState(false) + const [search, setSearch] = useState("") + + useEffect(() => { + if (isOpen) { + setTimeout(() => { + inputRef.current?.focus() + }, 100) // FIX: hacky + } + }, [isOpen]) + + const handleFileSelect = (file: CollectionFile) => { + onCollectionFileSelect(file) + } + + if (!files) return null + + return ( + { + setIsOpen(isOpen) + setSearch("") + }} + > + + + + + + setSearch(e.target.value)} + onKeyDown={e => e.stopPropagation()} + /> + + {selectedCollectionFiles + .filter(file => + file.name.toLowerCase().includes(search.toLowerCase()) + ) + .map(file => ( + selectedCollectionFile.id === file.id + )} + onSelect={handleFileSelect} + /> + ))} + + {files + .filter( + file => + !selectedCollectionFiles.some( + selectedCollectionFile => selectedCollectionFile.id === file.id + ) && file.name.toLowerCase().includes(search.toLowerCase()) + ) + .map(file => ( + selectedCollectionFile.id === file.id + )} + onSelect={handleFileSelect} + /> + ))} + + + ) +} + +interface CollectionFileItemProps { + file: CollectionFile + selected: boolean + onSelect: (file: CollectionFile) => void +} + +const CollectionFileItem: FC = ({ + file, + selected, + onSelect +}) => { + const handleSelect = () => { + onSelect(file) + } + + return ( +
+
+
+ +
+ +
{file.name}
+
+ + {selected && ( + + )} +
+ ) +} diff --git a/chatdesk-ui/components/sidebar/items/collections/collection-item.tsx b/chatdesk-ui/components/sidebar/items/collections/collection-item.tsx new file mode 100644 index 0000000..a9bd861 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/collections/collection-item.tsx @@ -0,0 +1,122 @@ +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { COLLECTION_DESCRIPTION_MAX, COLLECTION_NAME_MAX } from "@/db/limits" +import { Tables } from "@/supabase/types" +import { CollectionFile } from "@/types" +import { IconBooks } from "@tabler/icons-react" +import { FC, useState } from "react" +import { SidebarItem } from "../all/sidebar-display-item" +import { CollectionFileSelect } from "./collection-file-select" + +import { useTranslation } from 'react-i18next' + +interface CollectionItemProps { + collection: Tables<"collections"> +} + +export const CollectionItem: FC = ({ collection }) => { + + const { t } = useTranslation() + + const [name, setName] = useState(collection.name) + const [isTyping, setIsTyping] = useState(false) + const [description, setDescription] = useState(collection.description) + + const handleFileSelect = ( + file: CollectionFile, + setSelectedCollectionFiles: React.Dispatch< + React.SetStateAction + > + ) => { + setSelectedCollectionFiles(prevState => { + const isFileAlreadySelected = prevState.find( + selectedFile => selectedFile.id === file.id + ) + + if (isFileAlreadySelected) { + return prevState.filter(selectedFile => selectedFile.id !== file.id) + } else { + return [...prevState, file] + } + }) + } + + return ( + } + updateState={{ + name, + description + }} + renderInputs={(renderState: { + startingCollectionFiles: CollectionFile[] + setStartingCollectionFiles: React.Dispatch< + React.SetStateAction + > + selectedCollectionFiles: CollectionFile[] + setSelectedCollectionFiles: React.Dispatch< + React.SetStateAction + > + }) => { + return ( + <> +
+ + + + !renderState.selectedCollectionFiles.some( + selectedFile => + selectedFile.id === startingFile.id + ) + ), + ...renderState.selectedCollectionFiles.filter( + selectedFile => + !renderState.startingCollectionFiles.some( + startingFile => + startingFile.id === selectedFile.id + ) + ) + ] + } + onCollectionFileSelect={file => + handleFileSelect(file, renderState.setSelectedCollectionFiles) + } + /> +
+ +
+ + + setName(e.target.value)} + maxLength={COLLECTION_NAME_MAX} + /> +
+ +
+ + + setDescription(e.target.value)} + maxLength={COLLECTION_DESCRIPTION_MAX} + /> +
+ + ) + }} + /> + ) +} diff --git a/chatdesk-ui/components/sidebar/items/collections/create-collection.tsx b/chatdesk-ui/components/sidebar/items/collections/create-collection.tsx new file mode 100644 index 0000000..d4056cc --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/collections/create-collection.tsx @@ -0,0 +1,104 @@ +import { SidebarCreateItem } from "@/components/sidebar/items/all/sidebar-create-item" +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { ChatbotUIContext } from "@/context/context" +import { COLLECTION_DESCRIPTION_MAX, COLLECTION_NAME_MAX } from "@/db/limits" +import { TablesInsert } from "@/supabase/types" +import { CollectionFile } from "@/types" +import { FC, useContext, useState } from "react" +import { CollectionFileSelect } from "./collection-file-select" + +import { useTranslation } from 'react-i18next' + +interface CreateCollectionProps { + isOpen: boolean + onOpenChange: (isOpen: boolean) => void +} + +export const CreateCollection: FC = ({ + isOpen, + onOpenChange +}) => { + const { t } = useTranslation() + + const { profile, selectedWorkspace } = useContext(ChatbotUIContext) + + const [name, setName] = useState("") + const [isTyping, setIsTyping] = useState(false) + const [description, setDescription] = useState("") + const [selectedCollectionFiles, setSelectedCollectionFiles] = useState< + CollectionFile[] + >([]) + + const handleFileSelect = (file: CollectionFile) => { + setSelectedCollectionFiles(prevState => { + const isFileAlreadySelected = prevState.find( + selectedFile => selectedFile.id === file.id + ) + + if (isFileAlreadySelected) { + return prevState.filter(selectedFile => selectedFile.id !== file.id) + } else { + return [...prevState, file] + } + }) + } + + if (!profile) return null + if (!selectedWorkspace) return null + + return ( + ({ + user_id: profile.user_id, + collection_id: "", + file_id: file.id + })), + user_id: profile.user_id, + name, + description + } as TablesInsert<"collections"> + } + isOpen={isOpen} + isTyping={isTyping} + onOpenChange={onOpenChange} + renderInputs={() => ( + <> +
+ + + +
+ +
+ + + setName(e.target.value)} + maxLength={COLLECTION_NAME_MAX} + /> +
+ +
+ + + setDescription(e.target.value)} + maxLength={COLLECTION_DESCRIPTION_MAX} + /> +
+ + )} + /> + ) +} diff --git a/chatdesk-ui/components/sidebar/items/files/create-file.tsx b/chatdesk-ui/components/sidebar/items/files/create-file.tsx new file mode 100644 index 0000000..3a9a69f --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/files/create-file.tsx @@ -0,0 +1,97 @@ +import { ACCEPTED_FILE_TYPES } from "@/components/chat/chat-hooks/use-select-file-handler" +import { SidebarCreateItem } from "@/components/sidebar/items/all/sidebar-create-item" +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { ChatbotUIContext } from "@/context/context" +import { FILE_DESCRIPTION_MAX, FILE_NAME_MAX } from "@/db/limits" +import { TablesInsert } from "@/supabase/types" +import { FC, useContext, useState } from "react" + +import { useTranslation } from 'react-i18next' + +interface CreateFileProps { + isOpen: boolean + onOpenChange: (isOpen: boolean) => void +} + +export const CreateFile: FC = ({ isOpen, onOpenChange }) => { + const { t } = useTranslation() + + const { profile, selectedWorkspace } = useContext(ChatbotUIContext) + + const [name, setName] = useState("") + const [isTyping, setIsTyping] = useState(false) + const [description, setDescription] = useState("") + const [selectedFile, setSelectedFile] = useState(null) + + const handleSelectedFile = async (e: React.ChangeEvent) => { + if (!e.target.files) return + + const file = e.target.files[0] + + if (!file) return + + setSelectedFile(file) + const fileNameWithoutExtension = file.name.split(".").slice(0, -1).join(".") + setName(fileNameWithoutExtension) + } + + if (!profile) return null + if (!selectedWorkspace) return null + + return ( + + } + isOpen={isOpen} + isTyping={isTyping} + onOpenChange={onOpenChange} + renderInputs={() => ( + <> +
+ + + +
+ +
+ + + setName(e.target.value)} + maxLength={FILE_NAME_MAX} + /> +
+ +
+ + + setDescription(e.target.value)} + maxLength={FILE_DESCRIPTION_MAX} + /> +
+ + )} + /> + ) +} diff --git a/chatdesk-ui/components/sidebar/items/files/file-item.tsx b/chatdesk-ui/components/sidebar/items/files/file-item.tsx new file mode 100644 index 0000000..10758a3 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/files/file-item.tsx @@ -0,0 +1,97 @@ +import { FileIcon } from "@/components/ui/file-icon" +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { FILE_DESCRIPTION_MAX, FILE_NAME_MAX } from "@/db/limits" +import { getFileFromStorage } from "@/db/storage/files" +import { Tables } from "@/supabase/types" +import { FC, useState } from "react" +import { SidebarItem } from "../all/sidebar-display-item" +import { useTranslation } from 'react-i18next' + +interface FileItemProps { + file: Tables<"files"> +} + +export const FileItem: FC = ({ file }) => { + const { t } = useTranslation() + const [name, setName] = useState(file.name) + const [isTyping, setIsTyping] = useState(false) + const [description, setDescription] = useState(file.description) + + const getLinkAndView = async () => { + const link = await getFileFromStorage(file.file_path) + window.open(link, "_blank") + } + + return ( + } + updateState={{ name, description }} + renderInputs={() => ( + <> +
+ {t("side.view")} {file.name} +
+ +
+
{file.type}
+ +
{formatFileSize(file.size)}
+ +
{file.tokens.toLocaleString()} tokens
+
+ +
+ + + setName(e.target.value)} + maxLength={FILE_NAME_MAX} + /> +
+ +
+ + + setDescription(e.target.value)} + maxLength={FILE_DESCRIPTION_MAX} + /> +
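{/* Sketch, not from the original commit: expected output of the formatFileSize
    helper exported at the bottom of this file (repeated division by 1024,
    capped at GB, two decimal places):
      formatFileSize(512)             -> "512.00 bytes"
      formatFileSize(2048)            -> "2.00 KB"
      formatFileSize(5 * 1024 * 1024) -> "5.00 MB"
      formatFileSize(3 * 1024 ** 3)   -> "3.00 GB"
*/}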
+ + )} + /> + ) +} + +export const formatFileSize = (sizeInBytes: number): string => { + let size = sizeInBytes + let unit = "bytes" + + if (size >= 1024) { + size /= 1024 + unit = "KB" + } + + if (size >= 1024) { + size /= 1024 + unit = "MB" + } + + if (size >= 1024) { + size /= 1024 + unit = "GB" + } + + return `${size.toFixed(2)} ${unit}` +} diff --git a/chatdesk-ui/components/sidebar/items/folders/delete-folder.tsx b/chatdesk-ui/components/sidebar/items/folders/delete-folder.tsx new file mode 100644 index 0000000..1de9044 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/folders/delete-folder.tsx @@ -0,0 +1,143 @@ +import { Button } from "@/components/ui/button" +import { + Dialog, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger +} from "@/components/ui/dialog" +import { ChatbotUIContext } from "@/context/context" +import { deleteFolder } from "@/db/folders" +import { supabase } from "@/lib/supabase/browser-client" +import { Tables } from "@/supabase/types" +import { ContentType } from "@/types" +import { IconTrash } from "@tabler/icons-react" +import { FC, useContext, useRef, useState } from "react" +import { toast } from "sonner" +import { useTranslation } from 'react-i18next' + +interface DeleteFolderProps { + folder: Tables<"folders"> + contentType: ContentType +} + +export const DeleteFolder: FC = ({ + folder, + contentType +}) => { + const { t } = useTranslation() + const { + setChats, + setFolders, + setPresets, + setPrompts, + setFiles, + setCollections, + setAssistants, + setTools, + setModels + } = useContext(ChatbotUIContext) + + const buttonRef = useRef(null) + + const [showFolderDialog, setShowFolderDialog] = useState(false) + + const stateUpdateFunctions = { + chats: setChats, + presets: setPresets, + prompts: setPrompts, + files: setFiles, + collections: setCollections, + assistants: setAssistants, + tools: setTools, + models: setModels + } + + const handleDeleteFolderOnly = async () => { + await deleteFolder(folder.id) + + setFolders(prevState => prevState.filter(c => c.id !== folder.id)) + + setShowFolderDialog(false) + + const setStateFunction = stateUpdateFunctions[contentType] + + if (!setStateFunction) return + + setStateFunction((prevItems: any) => + prevItems.map((item: any) => { + if (item.folder_id === folder.id) { + return { + ...item, + folder_id: null + } + } + + return item + }) + ) + } + + const handleDeleteFolderAndItems = async () => { + const setStateFunction = stateUpdateFunctions[contentType] + + if (!setStateFunction) return + + const { error } = await supabase + .from(contentType) + .delete() + .eq("folder_id", folder.id) + + if (error) { + toast.error(error.message) + } + + setStateFunction((prevItems: any) => + prevItems.filter((item: any) => item.folder_id !== folder.id) + ) + + handleDeleteFolderOnly() + } + + return ( + + + + + + + + {t("side.delete")} {folder.name} + + + {t("side.confirmDeleteFolder")} + + + + + + + + + + + + + ) +} diff --git a/chatdesk-ui/components/sidebar/items/folders/folder-item.tsx b/chatdesk-ui/components/sidebar/items/folders/folder-item.tsx new file mode 100644 index 0000000..6768f21 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/folders/folder-item.tsx @@ -0,0 +1,114 @@ +import { cn } from "@/lib/utils" +import { Tables } from "@/supabase/types" +import { ContentType } from "@/types" +import { IconChevronDown, IconChevronRight } from "@tabler/icons-react" +import { FC, useRef, useState } from "react" +import { DeleteFolder } from "./delete-folder" 
+import { UpdateFolder } from "./update-folder" + +interface FolderProps { + folder: Tables<"folders"> + contentType: ContentType + children: React.ReactNode + onUpdateFolder: (itemId: string, folderId: string | null) => void +} + +export const Folder: FC = ({ + folder, + contentType, + children, + onUpdateFolder +}) => { + const itemRef = useRef(null) + + const [isDragOver, setIsDragOver] = useState(false) + const [isExpanded, setIsExpanded] = useState(false) + const [isHovering, setIsHovering] = useState(false) + + const handleDragEnter = (e: React.DragEvent) => { + e.preventDefault() + setIsDragOver(true) + } + + const handleDragLeave = (e: React.DragEvent) => { + e.preventDefault() + setIsDragOver(false) + } + + const handleDragOver = (e: React.DragEvent) => { + e.preventDefault() + setIsDragOver(true) + } + + const handleDrop = (e: React.DragEvent) => { + e.preventDefault() + + setIsDragOver(false) + const itemId = e.dataTransfer.getData("text/plain") + onUpdateFolder(itemId, folder.id) + } + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === "Enter") { + e.stopPropagation() + itemRef.current?.click() + } + } + + const handleClick = (e: React.MouseEvent) => { + setIsExpanded(!isExpanded) + } + + return ( +
setIsHovering(true)} + onMouseLeave={() => setIsHovering(false)} + > +
+
+
+ {isExpanded ? ( + + ) : ( + + )} + +
{folder.name}
+
+ + {isHovering && ( +
{ + e.stopPropagation() + e.preventDefault() + }} + className="ml-2 flex space-x-2" + > + + + +
+ )} +
+
+ + {isExpanded && ( +
{children}
+ )} +
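{/* Sketch, not from the original commit: the drag-and-drop wiring in brief.
    Sidebar items store their id with e.dataTransfer.setData("text/plain", id)
    on drag start (see SidebarDataList); handleDrop above reads that id back and
    re-parents the item via onUpdateFolder(itemId, folder.id), while dropping on
    the list background outside any folder maps to onUpdateFolder(itemId, null).
*/}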
+ ) +} diff --git a/chatdesk-ui/components/sidebar/items/folders/update-folder.tsx b/chatdesk-ui/components/sidebar/items/folders/update-folder.tsx new file mode 100644 index 0000000..067169d --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/folders/update-folder.tsx @@ -0,0 +1,78 @@ +import { Button } from "@/components/ui/button" +import { + Dialog, + DialogContent, + DialogFooter, + DialogHeader, + DialogTitle, + DialogTrigger +} from "@/components/ui/dialog" +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { ChatbotUIContext } from "@/context/context" +import { updateFolder } from "@/db/folders" +import { Tables } from "@/supabase/types" +import { IconEdit } from "@tabler/icons-react" +import { FC, useContext, useRef, useState } from "react" +import { useTranslation } from 'react-i18next' + +interface UpdateFolderProps { + folder: Tables<"folders"> +} + +export const UpdateFolder: FC = ({ folder }) => { + const { t } = useTranslation() + const { setFolders } = useContext(ChatbotUIContext) + + const buttonRef = useRef(null) + + const [showFolderDialog, setShowFolderDialog] = useState(false) + const [name, setName] = useState(folder.name) + + const handleUpdateFolder = async (e: React.MouseEvent) => { + const updatedFolder = await updateFolder(folder.id, { + name + }) + setFolders(prevState => + prevState.map(c => (c.id === folder.id ? updatedFolder : c)) + ) + + setShowFolderDialog(false) + } + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === "Enter") { + buttonRef.current?.click() + } + } + + return ( + + + + + + + + {t("side.editFolder")} + + +
+ + + setName(e.target.value)} /> +
+ + + + + + +
+
+ ) +} diff --git a/chatdesk-ui/components/sidebar/items/models/create-model.tsx b/chatdesk-ui/components/sidebar/items/models/create-model.tsx new file mode 100644 index 0000000..2c80c69 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/models/create-model.tsx @@ -0,0 +1,121 @@ +import { SidebarCreateItem } from "@/components/sidebar/items/all/sidebar-create-item" +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { ChatbotUIContext } from "@/context/context" +import { MODEL_NAME_MAX } from "@/db/limits" +import { TablesInsert } from "@/supabase/types" +import { FC, useContext, useState } from "react" +import { useTranslation, Trans } from 'react-i18next' + +interface CreateModelProps { + isOpen: boolean + onOpenChange: (isOpen: boolean) => void +} + +export const CreateModel: FC = ({ isOpen, onOpenChange }) => { + const { t } = useTranslation() + const { profile, selectedWorkspace } = useContext(ChatbotUIContext) + + const [isTyping, setIsTyping] = useState(false) + + const [apiKey, setApiKey] = useState("") + const [baseUrl, setBaseUrl] = useState("") + const [description, setDescription] = useState("") + const [modelId, setModelId] = useState("") + const [name, setName] = useState("") + const [contextLength, setContextLength] = useState(4096) + + if (!profile || !selectedWorkspace) return null + + return ( + + } + renderInputs={() => ( + <> +
+
{t("side.createCustomModel")}
+ +
+ {/* Your API *must* be compatible + with the OpenAI SDK. */} + }} /> +
+
+ +
+ + + setName(e.target.value)} + maxLength={MODEL_NAME_MAX} + /> +
+ +
+ + + setModelId(e.target.value)} + /> +
+ +
+ + + setBaseUrl(e.target.value)} + /> + +
+ {t("side.apiCompatibilityNotice")} + +
+
+ +
+ + + setApiKey(e.target.value)} + /> +
+ +
+ + + setContextLength(parseInt(e.target.value))} + /> +
+ + )} + /> + ) +} diff --git a/chatdesk-ui/components/sidebar/items/models/model-item.tsx b/chatdesk-ui/components/sidebar/items/models/model-item.tsx new file mode 100644 index 0000000..1bc8c48 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/models/model-item.tsx @@ -0,0 +1,106 @@ +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { MODEL_NAME_MAX } from "@/db/limits" +import { Tables, TablesUpdate } from "@/supabase/types" +import { IconSparkles } from "@tabler/icons-react" +import { FC, useState } from "react" +import { SidebarItem } from "../all/sidebar-display-item" + +import { useTranslation } from 'react-i18next' + +interface ModelItemProps { + model: Tables<"models"> +} + +export const ModelItem: FC = ({ model }) => { + const { t } = useTranslation() + + const [isTyping, setIsTyping] = useState(false) + + const [apiKey, setApiKey] = useState(model.api_key) + const [baseUrl, setBaseUrl] = useState(model.base_url) + const [description, setDescription] = useState(model.description) + const [modelId, setModelId] = useState(model.model_id) + const [name, setName] = useState(model.name) + const [contextLength, setContextLength] = useState(model.context_length) + + return ( + } + updateState={ + { + api_key: apiKey, + base_url: baseUrl, + description, + context_length: contextLength, + model_id: modelId, + name + } as TablesUpdate<"models"> + } + renderInputs={() => ( + <> +
+ + + setName(e.target.value)} + maxLength={MODEL_NAME_MAX} + /> +
+ +
+ + + setModelId(e.target.value)} + /> +
+ +
+ + + setBaseUrl(e.target.value)} + /> + +
+ {t("side.apiCompatibilityNotice")} +
+
+ +
+ + + setApiKey(e.target.value)} + /> +
+ +
+ + + setContextLength(parseInt(e.target.value))} + /> +
+ + )} + /> + ) +} diff --git a/chatdesk-ui/components/sidebar/items/presets/create-preset.tsx b/chatdesk-ui/components/sidebar/items/presets/create-preset.tsx new file mode 100644 index 0000000..e15431c --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/presets/create-preset.tsx @@ -0,0 +1,84 @@ +import { SidebarCreateItem } from "@/components/sidebar/items/all/sidebar-create-item" +import { ChatSettingsForm } from "@/components/ui/chat-settings-form" +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { ChatbotUIContext } from "@/context/context" +import { PRESET_NAME_MAX } from "@/db/limits" +import { TablesInsert } from "@/supabase/types" +import { FC, useContext, useState } from "react" + +import { useTranslation } from 'react-i18next' + +interface CreatePresetProps { + isOpen: boolean + onOpenChange: (isOpen: boolean) => void +} + +export const CreatePreset: FC = ({ + isOpen, + onOpenChange +}) => { + const { t } = useTranslation() + const { profile, selectedWorkspace } = useContext(ChatbotUIContext) + + const [name, setName] = useState("") + const [isTyping, setIsTyping] = useState(false) + const [description, setDescription] = useState("") + const [presetChatSettings, setPresetChatSettings] = useState({ + model: selectedWorkspace?.default_model, + prompt: selectedWorkspace?.default_prompt, + temperature: selectedWorkspace?.default_temperature, + contextLength: selectedWorkspace?.default_context_length, + includeProfileContext: selectedWorkspace?.include_profile_context, + includeWorkspaceInstructions: + selectedWorkspace?.include_workspace_instructions, + embeddingsProvider: selectedWorkspace?.embeddings_provider + }) + + if (!profile) return null + if (!selectedWorkspace) return null + + return ( + + } + renderInputs={() => ( + <> +
+ + + setName(e.target.value)} + maxLength={PRESET_NAME_MAX} + /> +
+ + + + )} + /> + ) +} diff --git a/chatdesk-ui/components/sidebar/items/presets/preset-item.tsx b/chatdesk-ui/components/sidebar/items/presets/preset-item.tsx new file mode 100644 index 0000000..3544b9f --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/presets/preset-item.tsx @@ -0,0 +1,78 @@ +import { ModelIcon } from "@/components/models/model-icon" +import { ChatSettingsForm } from "@/components/ui/chat-settings-form" +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { PRESET_NAME_MAX } from "@/db/limits" +import { LLM_LIST } from "@/lib/models/llm/llm-list" +import { Tables } from "@/supabase/types" +import { FC, useState } from "react" +import { SidebarItem } from "../all/sidebar-display-item" + +import { useTranslation } from 'react-i18next' + +interface PresetItemProps { + preset: Tables<"presets"> +} + +export const PresetItem: FC = ({ preset }) => { + const { t } = useTranslation() + const [name, setName] = useState(preset.name) + const [isTyping, setIsTyping] = useState(false) + const [description, setDescription] = useState(preset.description) + const [presetChatSettings, setPresetChatSettings] = useState({ + model: preset.model, + prompt: preset.prompt, + temperature: preset.temperature, + contextLength: preset.context_length, + includeProfileContext: preset.include_profile_context, + includeWorkspaceInstructions: preset.include_workspace_instructions + }) + + const modelDetails = LLM_LIST.find(model => model.modelId === preset.model) + + return ( + + } + updateState={{ + name, + description, + include_profile_context: presetChatSettings.includeProfileContext, + include_workspace_instructions: + presetChatSettings.includeWorkspaceInstructions, + context_length: presetChatSettings.contextLength, + model: presetChatSettings.model, + prompt: presetChatSettings.prompt, + temperature: presetChatSettings.temperature + }} + renderInputs={() => ( + <> +
+ + + setName(e.target.value)} + maxLength={PRESET_NAME_MAX} + /> +
+ + + + )} + /> + ) +} diff --git a/chatdesk-ui/components/sidebar/items/prompts/create-prompt.tsx b/chatdesk-ui/components/sidebar/items/prompts/create-prompt.tsx new file mode 100644 index 0000000..1c01d87 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/prompts/create-prompt.tsx @@ -0,0 +1,74 @@ +import { SidebarCreateItem } from "@/components/sidebar/items/all/sidebar-create-item" +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { TextareaAutosize } from "@/components/ui/textarea-autosize" +import { ChatbotUIContext } from "@/context/context" +import { PROMPT_NAME_MAX } from "@/db/limits" +import { TablesInsert } from "@/supabase/types" +import { FC, useContext, useState } from "react" +import { useTranslation } from 'react-i18next' + +interface CreatePromptProps { + isOpen: boolean + onOpenChange: (isOpen: boolean) => void +} + +export const CreatePrompt: FC = ({ + isOpen, + onOpenChange +}) => { + const { t } = useTranslation() + const { profile, selectedWorkspace } = useContext(ChatbotUIContext) + const [isTyping, setIsTyping] = useState(false) + const [name, setName] = useState("") + const [content, setContent] = useState("") + + if (!profile) return null + if (!selectedWorkspace) return null + + return ( + + } + renderInputs={() => ( + <> +
+ + + setName(e.target.value)} + maxLength={PROMPT_NAME_MAX} + onCompositionStart={() => setIsTyping(true)} + onCompositionEnd={() => setIsTyping(false)} + /> +
+ +
+ + + setIsTyping(true)} + onCompositionEnd={() => setIsTyping(false)} + /> +
+ + )} + /> + ) +} diff --git a/chatdesk-ui/components/sidebar/items/prompts/prompt-item.tsx b/chatdesk-ui/components/sidebar/items/prompts/prompt-item.tsx new file mode 100644 index 0000000..39fdea9 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/prompts/prompt-item.tsx @@ -0,0 +1,60 @@ +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { TextareaAutosize } from "@/components/ui/textarea-autosize" +import { PROMPT_NAME_MAX } from "@/db/limits" +import { Tables } from "@/supabase/types" +import { IconPencil } from "@tabler/icons-react" +import { FC, useState } from "react" +import { SidebarItem } from "../all/sidebar-display-item" +import { useTranslation } from 'react-i18next' + +interface PromptItemProps { + prompt: Tables<"prompts"> +} + +export const PromptItem: FC = ({ prompt }) => { + const { t } = useTranslation() + + const [name, setName] = useState(prompt.name) + const [content, setContent] = useState(prompt.content) + const [isTyping, setIsTyping] = useState(false) + return ( + } + updateState={{ name, content }} + renderInputs={() => ( + <> +
+ + + setName(e.target.value)} + maxLength={PROMPT_NAME_MAX} + onCompositionStart={() => setIsTyping(true)} + onCompositionEnd={() => setIsTyping(false)} + /> +
+ +
+ + + setIsTyping(true)} + onCompositionEnd={() => setIsTyping(false)} + /> +
+ + )} + /> + ) +} diff --git a/chatdesk-ui/components/sidebar/items/tools/create-tool.tsx b/chatdesk-ui/components/sidebar/items/tools/create-tool.tsx new file mode 100644 index 0000000..54300e0 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/tools/create-tool.tsx @@ -0,0 +1,175 @@ +import { SidebarCreateItem } from "@/components/sidebar/items/all/sidebar-create-item" +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { TextareaAutosize } from "@/components/ui/textarea-autosize" +import { ChatbotUIContext } from "@/context/context" +import { TOOL_DESCRIPTION_MAX, TOOL_NAME_MAX } from "@/db/limits" +import { validateOpenAPI } from "@/lib/openapi-conversion" +import { TablesInsert } from "@/supabase/types" +import { FC, useContext, useState } from "react" +import { useTranslation } from 'react-i18next' + +interface CreateToolProps { + isOpen: boolean + onOpenChange: (isOpen: boolean) => void +} + +export const CreateTool: FC = ({ isOpen, onOpenChange }) => { + const { t } = useTranslation() + + const { profile, selectedWorkspace } = useContext(ChatbotUIContext) + + const [name, setName] = useState("") + const [isTyping, setIsTyping] = useState(false) + const [description, setDescription] = useState("") + const [url, setUrl] = useState("") + const [customHeaders, setCustomHeaders] = useState("") + const [schema, setSchema] = useState("") + const [schemaError, setSchemaError] = useState("") + + if (!profile || !selectedWorkspace) return null + + return ( + + } + isOpen={isOpen} + isTyping={isTyping} + renderInputs={() => ( + <> +
+ + + setName(e.target.value)} + maxLength={TOOL_NAME_MAX} + /> +
+ +
+ + + setDescription(e.target.value)} + maxLength={TOOL_DESCRIPTION_MAX} + /> +
+ + {/*
+ + + setUrl(e.target.value)} + /> +
*/} + + {/*
+
+ + + +
+ +
+ + + +
+ +
+ + + +
+
*/} + +
+ + + +
+ +
+ + + { + setSchema(value) + + try { + const parsedSchema = JSON.parse(value) + validateOpenAPI(parsedSchema) + .then(() => setSchemaError("")) // Clear error if validation is successful + .catch(error => setSchemaError(error.message)) // Set specific validation error message + } catch (error) { + setSchemaError("Invalid JSON format") // Set error for invalid JSON format + } + }} + minRows={15} + /> + +
{schemaError}
+
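{/* Sketch, not from the original commit: the validation flow wired into the
    schema textarea above, using validateOpenAPI from @/lib/openapi-conversion
    (helper name is illustrative):
      const checkSchema = (value: string, setError: (msg: string) => void) => {
        try {
          const parsed = JSON.parse(value)
          validateOpenAPI(parsed)
            .then(() => setError(""))
            .catch(err => setError(err.message))
        } catch {
          setError("Invalid JSON format")
        }
      }
    A JSON parse failure short-circuits before the async OpenAPI validation runs.
*/}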
+ + )} + onOpenChange={onOpenChange} + /> + ) +} diff --git a/chatdesk-ui/components/sidebar/items/tools/tool-item.tsx b/chatdesk-ui/components/sidebar/items/tools/tool-item.tsx new file mode 100644 index 0000000..d1fe5d8 --- /dev/null +++ b/chatdesk-ui/components/sidebar/items/tools/tool-item.tsx @@ -0,0 +1,169 @@ +import { Input } from "@/components/ui/input" +import { Label } from "@/components/ui/label" +import { TextareaAutosize } from "@/components/ui/textarea-autosize" +import { TOOL_DESCRIPTION_MAX, TOOL_NAME_MAX } from "@/db/limits" +import { validateOpenAPI } from "@/lib/openapi-conversion" +import { Tables } from "@/supabase/types" +import { IconBolt } from "@tabler/icons-react" +import { FC, useState } from "react" +import { SidebarItem } from "../all/sidebar-display-item" +import { useTranslation } from 'react-i18next' + +interface ToolItemProps { + tool: Tables<"tools"> +} + +export const ToolItem: FC = ({ tool }) => { + const { t } = useTranslation() + + const [name, setName] = useState(tool.name) + const [isTyping, setIsTyping] = useState(false) + const [description, setDescription] = useState(tool.description) + const [url, setUrl] = useState(tool.url) + const [customHeaders, setCustomHeaders] = useState( + tool.custom_headers as string + ) + const [schema, setSchema] = useState(tool.schema as string) + const [schemaError, setSchemaError] = useState("") + + return ( + } + updateState={{ + name, + description, + url, + custom_headers: customHeaders, + schema + }} + renderInputs={() => ( + <> +
+ + + setName(e.target.value)} + maxLength={TOOL_NAME_MAX} + /> +
+ +
+ + + setDescription(e.target.value)} + maxLength={TOOL_DESCRIPTION_MAX} + /> +
+ + {/*
+ + + setUrl(e.target.value)} + /> +
*/} + + {/*
+
+ + + +
+ +
+ + + +
+ +
+ + + +
+
*/} + +
+ + + +
+ +
+ + + { + setSchema(value) + + try { + const parsedSchema = JSON.parse(value) + validateOpenAPI(parsedSchema) + .then(() => setSchemaError("")) // Clear error if validation is successful + .catch(error => setSchemaError(error.message)) // Set specific validation error message + } catch (error) { + setSchemaError("Invalid JSON format") // Set error for invalid JSON format + } + }} + minRows={15} + /> + +
{schemaError}
+
+ + )} + /> + ) +} diff --git a/chatdesk-ui/components/sidebar/sidebar-content.tsx b/chatdesk-ui/components/sidebar/sidebar-content.tsx new file mode 100644 index 0000000..1c5bd08 --- /dev/null +++ b/chatdesk-ui/components/sidebar/sidebar-content.tsx @@ -0,0 +1,50 @@ +import { Tables } from "@/supabase/types" +import { ContentType, DataListType } from "@/types" +import { FC, useState } from "react" +import { SidebarCreateButtons } from "./sidebar-create-buttons" +import { SidebarDataList } from "./sidebar-data-list" +import { SidebarSearch } from "./sidebar-search" + +interface SidebarContentProps { + contentType: ContentType + data: DataListType + folders: Tables<"folders">[] +} + +export const SidebarContent: FC = ({ + contentType, + data, + folders +}) => { + const [searchTerm, setSearchTerm] = useState("") + + const filteredData: any = data.filter(item => + item.name.toLowerCase().includes(searchTerm.toLowerCase()) + ) + + return ( + // Subtract 50px for the height of the workspace settings +
+
+ 0} + /> +
+ +
+ +
+ + +
+ ) +} diff --git a/chatdesk-ui/components/sidebar/sidebar-create-buttons.tsx b/chatdesk-ui/components/sidebar/sidebar-create-buttons.tsx new file mode 100644 index 0000000..b164cd3 --- /dev/null +++ b/chatdesk-ui/components/sidebar/sidebar-create-buttons.tsx @@ -0,0 +1,177 @@ +import { useChatHandler } from "@/components/chat/chat-hooks/use-chat-handler" +import { ChatbotUIContext } from "@/context/context" +import { createFolder } from "@/db/folders" +import { ContentType } from "@/types" +import { IconFolderPlus, IconPlus } from "@tabler/icons-react" +import { FC, useContext, useState } from "react" +import { Button } from "../ui/button" +import { CreateAssistant } from "./items/assistants/create-assistant" +import { CreateCollection } from "./items/collections/create-collection" +import { CreateFile } from "./items/files/create-file" +import { CreateModel } from "./items/models/create-model" +import { CreatePreset } from "./items/presets/create-preset" +import { CreatePrompt } from "./items/prompts/create-prompt" +import { CreateTool } from "./items/tools/create-tool" + +import { useTranslation } from "react-i18next"; + +interface SidebarCreateButtonsProps { + contentType: ContentType + hasData: boolean +} + +export const SidebarCreateButtons: FC = ({ + contentType, + hasData +}) => { + + const { t, i18n } = useTranslation(); + + const { profile, selectedWorkspace, folders, setFolders } = + useContext(ChatbotUIContext) + const { handleNewChat } = useChatHandler() + + const [isCreatingPrompt, setIsCreatingPrompt] = useState(false) + const [isCreatingPreset, setIsCreatingPreset] = useState(false) + const [isCreatingFile, setIsCreatingFile] = useState(false) + const [isCreatingCollection, setIsCreatingCollection] = useState(false) + const [isCreatingAssistant, setIsCreatingAssistant] = useState(false) + const [isCreatingTool, setIsCreatingTool] = useState(false) + const [isCreatingModel, setIsCreatingModel] = useState(false) + + const handleCreateFolder = async () => { + if (!profile) return + if (!selectedWorkspace) return + + const createdFolder = await createFolder({ + user_id: profile.user_id, + workspace_id: selectedWorkspace.id, + name: "New Folder", + description: "", + type: contentType + }) + setFolders([...folders, createdFolder]) + } + + const getCreateFunction = () => { + switch (contentType) { + case "chats": + return async () => { + handleNewChat() + } + + case "presets": + return async () => { + setIsCreatingPreset(true) + } + + case "prompts": + return async () => { + setIsCreatingPrompt(true) + } + + case "files": + return async () => { + setIsCreatingFile(true) + } + + case "collections": + return async () => { + setIsCreatingCollection(true) + } + + case "assistants": + return async () => { + setIsCreatingAssistant(true) + } + + case "tools": + return async () => { + setIsCreatingTool(true) + } + + case "models": + return async () => { + setIsCreatingModel(true) + } + + default: + break + } + } + + // 判断需要大写首字母的语言 + const needsUpperCaseFirstLetter = (language: string) => { + const languagesRequiringUpperCase = ['en', 'de', 'fr', 'es', 'it']; // 其他需要大写首字母的语言 + return languagesRequiringUpperCase.includes(language); + }; + + // 对动态内容进行首字母大写的处理 + const getCapitalizedContentType = (contentType: string, language: string) => { + if (needsUpperCaseFirstLetter(language)) { + return contentType.charAt(0).toUpperCase() + contentType.slice(1, contentType.length - 1); // 保留去掉最后一个字符 + } + return contentType; // 不需要大写的语言返回原始值 + }; + + return ( +
+ + + {hasData && ( + + )} + + {isCreatingPrompt && ( + + )} + + {isCreatingPreset && ( + + )} + + {isCreatingFile && ( + + )} + + {isCreatingCollection && ( + + )} + + {isCreatingAssistant && ( + + )} + + {isCreatingTool && ( + + )} + + {isCreatingModel && ( + + )} +
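{/* Sketch, not from the original commit: behaviour of getCapitalizedContentType
    above. For languages listed in needsUpperCaseFirstLetter (en, de, fr, es, it)
    it upper-cases the first letter and drops the final character (the plural "s");
    for other languages it returns contentType unchanged:
      getCapitalizedContentType("prompts", "en") -> "Prompt"
      getCapitalizedContentType("prompts", "zh") -> "prompts"
*/}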
+ ) +} diff --git a/chatdesk-ui/components/sidebar/sidebar-data-list.tsx b/chatdesk-ui/components/sidebar/sidebar-data-list.tsx new file mode 100644 index 0000000..cea85db --- /dev/null +++ b/chatdesk-ui/components/sidebar/sidebar-data-list.tsx @@ -0,0 +1,372 @@ +import { ChatbotUIContext } from "@/context/context" +import { updateAssistant } from "@/db/assistants" +import { updateChat } from "@/db/chats" +import { updateCollection } from "@/db/collections" +import { updateFile } from "@/db/files" +import { updateModel } from "@/db/models" +import { updatePreset } from "@/db/presets" +import { updatePrompt } from "@/db/prompts" +import { updateTool } from "@/db/tools" +import { cn } from "@/lib/utils" +import { Tables } from "@/supabase/types" +import { ContentType, DataItemType, DataListType } from "@/types" +import { FC, useContext, useEffect, useRef, useState } from "react" +import { Separator } from "../ui/separator" +import { AssistantItem } from "./items/assistants/assistant-item" +import { ChatItem } from "./items/chat/chat-item" +import { CollectionItem } from "./items/collections/collection-item" +import { FileItem } from "./items/files/file-item" +import { Folder } from "./items/folders/folder-item" +import { ModelItem } from "./items/models/model-item" +import { PresetItem } from "./items/presets/preset-item" +import { PromptItem } from "./items/prompts/prompt-item" +import { ToolItem } from "./items/tools/tool-item" + +import { useTranslation } from "react-i18next"; + +interface SidebarDataListProps { + contentType: ContentType + data: DataListType + folders: Tables<"folders">[] +} + +export const SidebarDataList: FC = ({ + contentType, + data, + folders +}) => { + + const { t } = useTranslation(); + + const dateCategories = [ + { key: "Today", label: t("side.chatTime.Today") }, + { key: "Yesterday", label: t("side.chatTime.Yesterday") }, + { key: "PreviousWeek", label: t("side.chatTime.PreviousWeek") }, + { key: "Older", label: t("side.chatTime.Older") } + ] + + const { + setChats, + setPresets, + setPrompts, + setFiles, + setCollections, + setAssistants, + setTools, + setModels + } = useContext(ChatbotUIContext) + + const divRef = useRef(null) + + const [isOverflowing, setIsOverflowing] = useState(false) + const [isDragOver, setIsDragOver] = useState(false) + + const getDataListComponent = ( + contentType: ContentType, + item: DataItemType + ) => { + switch (contentType) { + case "chats": + return } /> + + case "presets": + return } /> + + case "prompts": + return } /> + + case "files": + return } /> + + case "collections": + return ( + } + /> + ) + + case "assistants": + return ( + } + /> + ) + + case "tools": + return } /> + + case "models": + return } /> + + default: + return null + } + } + + const getSortedData = ( + data: any, + dateCategory: "Today" | "Yesterday" | "PreviousWeek" | "Older" + ) => { + const now = new Date() + const todayStart = new Date(now.setHours(0, 0, 0, 0)) + const yesterdayStart = new Date( + new Date().setDate(todayStart.getDate() - 1) + ) + const oneWeekAgoStart = new Date( + new Date().setDate(todayStart.getDate() - 7) + ) + + return data + .filter((item: any) => { + const itemDate = new Date(item.updated_at || item.created_at) + switch (dateCategory) { + case "Today": + return itemDate >= todayStart + case "Yesterday": + return itemDate >= yesterdayStart && itemDate < todayStart + case "PreviousWeek": + return itemDate >= oneWeekAgoStart && itemDate < yesterdayStart + case "Older": + return itemDate < oneWeekAgoStart + default: + return true + 
} + }) + .sort( + ( + a: { updated_at: string; created_at: string }, + b: { updated_at: string; created_at: string } + ) => + new Date(b.updated_at || b.created_at).getTime() - + new Date(a.updated_at || a.created_at).getTime() + ) + } + + const updateFunctions = { + chats: updateChat, + presets: updatePreset, + prompts: updatePrompt, + files: updateFile, + collections: updateCollection, + assistants: updateAssistant, + tools: updateTool, + models: updateModel + } + + const stateUpdateFunctions = { + chats: setChats, + presets: setPresets, + prompts: setPrompts, + files: setFiles, + collections: setCollections, + assistants: setAssistants, + tools: setTools, + models: setModels + } + + const updateFolder = async (itemId: string, folderId: string | null) => { + const item: any = data.find(item => item.id === itemId) + + if (!item) return null + + const updateFunction = updateFunctions[contentType] + const setStateFunction = stateUpdateFunctions[contentType] + + if (!updateFunction || !setStateFunction) return + + const updatedItem = await updateFunction(item.id, { + folder_id: folderId + }) + + setStateFunction((items: any) => + items.map((item: any) => + item.id === updatedItem.id ? updatedItem : item + ) + ) + } + + const handleDragEnter = (e: React.DragEvent) => { + e.preventDefault() + setIsDragOver(true) + } + + const handleDragLeave = (e: React.DragEvent) => { + e.preventDefault() + setIsDragOver(false) + } + + const handleDragStart = (e: React.DragEvent, id: string) => { + e.dataTransfer.setData("text/plain", id) + } + + const handleDragOver = (e: React.DragEvent) => { + e.preventDefault() + } + + const handleDrop = (e: React.DragEvent) => { + e.preventDefault() + + const target = e.target as Element + + if (!target.closest("#folder")) { + const itemId = e.dataTransfer.getData("text/plain") + updateFolder(itemId, null) + } + + setIsDragOver(false) + } + + useEffect(() => { + if (divRef.current) { + setIsOverflowing( + divRef.current.scrollHeight > divRef.current.clientHeight + ) + } + }, [data]) + + const dataWithFolders = data.filter(item => item.folder_id) + const dataWithoutFolders = data.filter(item => item.folder_id === null) + + // 获取 "No {contentType}" 的国际化文本 + const getNoContentTypeText = (contentType: string) => { + const translatedContentType = t(`contentType.${contentType}`); + return t('side.sidebarNoContentType', { contentType: translatedContentType }) + "."; + }; + + return ( + <> +
+ {data.length === 0 && ( +
+
+ {getNoContentTypeText(contentType)} +
+
+ )} + + {(dataWithFolders.length > 0 || dataWithoutFolders.length > 0) && ( +
+ {folders.map(folder => ( + + {dataWithFolders + .filter(item => item.folder_id === folder.id) + .map(item => ( +
handleDragStart(e, item.id)} + > + {getDataListComponent(contentType, item)} +
+ ))} +
+ ))} + + {folders.length > 0 && } + + {contentType === "chats" ? ( + <> + {/* {["Today", "Yesterday", "Previous Week", "Older"].map( + dateCategory => { + const sortedData = getSortedData( + dataWithoutFolders, + dateCategory as + | "Today" + | "Yesterday" + | "Previous Week" + | "Older" + ) */} + {dateCategories.map(({ key, label }) => { + const sortedData = getSortedData( + dataWithoutFolders, + key as "Today" | "Yesterday" | "PreviousWeek" | "Older" + ) + + return ( + // sortedData.length > 0 && ( + //
+ //
+ // {dateCategory} + //
+ sortedData.length > 0 && ( +
{/* ✅ use the date-category key in place of the removed local variable */} +
+ {label} {/* ✅ render the translated category label */} +
+ +
+ {sortedData.map((item: any) => ( +
handleDragStart(e, item.id)} + > + {getDataListComponent(contentType, item)} +
+ ))} +
+
+ ) + ) + } + )} + + ) : ( +
+ {dataWithoutFolders.map(item => { + return ( +
handleDragStart(e, item.id)} + > + {getDataListComponent(contentType, item)} +
+ ) + })} +
+ )} +
+ )} +
+ +
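{/* Sketch, not from the original commit: the buckets produced by getSortedData
    earlier in this component, with todayStart at local midnight:
      "Today"         itemDate >= todayStart
      "Yesterday"     yesterdayStart <= itemDate < todayStart
      "PreviousWeek"  oneWeekAgoStart <= itemDate < yesterdayStart
      "Older"         itemDate < oneWeekAgoStart
    Caveat: yesterdayStart and oneWeekAgoStart are built with setDate() on the
    current timestamp, so they sit at "now minus N days", not at exact midnights.
*/}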
+ + ) +} diff --git a/chatdesk-ui/components/sidebar/sidebar-search.tsx b/chatdesk-ui/components/sidebar/sidebar-search.tsx new file mode 100644 index 0000000..745ab2e --- /dev/null +++ b/chatdesk-ui/components/sidebar/sidebar-search.tsx @@ -0,0 +1,28 @@ +import { ContentType } from "@/types" +import { FC } from "react" +import { Input } from "../ui/input" + +import { useTranslation } from "react-i18next"; + +interface SidebarSearchProps { + contentType: ContentType + searchTerm: string + setSearchTerm: Function +} + +export const SidebarSearch: FC = ({ + contentType, + searchTerm, + setSearchTerm +}) => { + + const { t } = useTranslation(); + + return ( + setSearchTerm(e.target.value)} + /> + ) +} diff --git a/chatdesk-ui/components/sidebar/sidebar-switch-item.tsx b/chatdesk-ui/components/sidebar/sidebar-switch-item.tsx new file mode 100644 index 0000000..455c782 --- /dev/null +++ b/chatdesk-ui/components/sidebar/sidebar-switch-item.tsx @@ -0,0 +1,36 @@ +import { ContentType } from "@/types" +import { FC } from "react" +import { TabsTrigger } from "../ui/tabs" +import { WithTooltip } from "../ui/with-tooltip" +import { useTranslation } from 'react-i18next' + +interface SidebarSwitchItemProps { + contentType: ContentType + icon: React.ReactNode + onContentTypeChange: (contentType: ContentType) => void +} + +export const SidebarSwitchItem: FC = ({ + contentType, + icon, + onContentTypeChange +}) => { + const { t } = useTranslation() + return ( + {contentType[0].toUpperCase() + contentType.substring(1)}
+ // } + display={
{t(`contentType.${contentType}`)}
} + trigger={ + onContentTypeChange(contentType as ContentType)} + > + {icon} + + } + /> + ) +} diff --git a/chatdesk-ui/components/sidebar/sidebar-switcher.tsx b/chatdesk-ui/components/sidebar/sidebar-switcher.tsx new file mode 100644 index 0000000..454d942 --- /dev/null +++ b/chatdesk-ui/components/sidebar/sidebar-switcher.tsx @@ -0,0 +1,96 @@ +import { ContentType } from "@/types" +import { + IconAdjustmentsHorizontal, + IconBolt, + IconBooks, + IconFile, + IconMessage, + IconPencil, + IconRobotFace, + IconSparkles +} from "@tabler/icons-react" +import { FC } from "react" +import { TabsList } from "../ui/tabs" +import { WithTooltip } from "../ui/with-tooltip" +import { ProfileSettings } from "../utility/profile-settings" +import { SidebarSwitchItem } from "./sidebar-switch-item" + +import { useTranslation } from 'react-i18next' + +export const SIDEBAR_ICON_SIZE = 28 + +interface SidebarSwitcherProps { + onContentTypeChange: (contentType: ContentType) => void +} + +export const SidebarSwitcher: FC = ({ + onContentTypeChange +}) => { + const { t } = useTranslation() + return ( +
+ + } + contentType="chats" + onContentTypeChange={onContentTypeChange} + /> + + } + contentType="presets" + onContentTypeChange={onContentTypeChange} + /> + + } + contentType="prompts" + onContentTypeChange={onContentTypeChange} + /> + + } + contentType="models" + onContentTypeChange={onContentTypeChange} + /> + + } + contentType="files" + onContentTypeChange={onContentTypeChange} + /> + + } + contentType="collections" + onContentTypeChange={onContentTypeChange} + /> + + } + contentType="assistants" + onContentTypeChange={onContentTypeChange} + /> + + } + contentType="tools" + onContentTypeChange={onContentTypeChange} + /> + + +
+ {/* TODO */} + {/* Import
} trigger={} /> */} + + {/* TODO */} + {/* */} + + {t("side.profileSettings")}
} + trigger={} + /> +
+ + ) +} diff --git a/chatdesk-ui/components/sidebar/sidebar.tsx b/chatdesk-ui/components/sidebar/sidebar.tsx new file mode 100644 index 0000000..a729a0f --- /dev/null +++ b/chatdesk-ui/components/sidebar/sidebar.tsx @@ -0,0 +1,116 @@ +import { ChatbotUIContext } from "@/context/context" +import { Tables } from "@/supabase/types" +import { ContentType } from "@/types" +import { FC, useContext } from "react" +import { SIDEBAR_WIDTH } from "../ui/dashboard" +import { TabsContent } from "../ui/tabs" +import { WorkspaceSwitcher } from "../utility/workspace-switcher" +import { WorkspaceSettings } from "../workspace/workspace-settings" +import { SidebarContent } from "./sidebar-content" + +import { useTranslation } from "react-i18next"; + +interface SidebarProps { + contentType: ContentType + showSidebar: boolean +} + +export const Sidebar: FC = ({ contentType, showSidebar }) => { + + const { t } = useTranslation(); + + const { + folders, + chats, + presets, + prompts, + files, + collections, + assistants, + tools, + models + } = useContext(ChatbotUIContext) + + const chatFolders = folders.filter(folder => folder.type === "chats") + const presetFolders = folders.filter(folder => folder.type === "presets") + const promptFolders = folders.filter(folder => folder.type === "prompts") + const filesFolders = folders.filter(folder => folder.type === "files") + const collectionFolders = folders.filter( + folder => folder.type === "collections" + ) + const assistantFolders = folders.filter( + folder => folder.type === "assistants" + ) + const toolFolders = folders.filter(folder => folder.type === "tools") + const modelFolders = folders.filter(folder => folder.type === "models") + + const renderSidebarContent = ( + contentType: ContentType, + data: any[], + folders: Tables<"folders">[] + ) => { + return ( + + ) + } + + return ( + +
+
+ + + +
+ + {(() => { + switch (contentType) { + case "chats": + return renderSidebarContent("chats", chats, chatFolders) + + case "presets": + return renderSidebarContent("presets", presets, presetFolders) + + case "prompts": + return renderSidebarContent("prompts", prompts, promptFolders) + + case "files": + return renderSidebarContent("files", files, filesFolders) + + case "collections": + return renderSidebarContent( + "collections", + collections, + collectionFolders + ) + + case "assistants": + return renderSidebarContent( + "assistants", + assistants, + assistantFolders + ) + + case "tools": + return renderSidebarContent("tools", tools, toolFolders) + + case "models": + return renderSidebarContent("models", models, modelFolders) + + default: + return null + } + })()} +
+
+ ) +} diff --git a/chatdesk-ui/components/ui/accordion.tsx b/chatdesk-ui/components/ui/accordion.tsx new file mode 100644 index 0000000..791ca2c --- /dev/null +++ b/chatdesk-ui/components/ui/accordion.tsx @@ -0,0 +1,58 @@ +"use client" + +import * as React from "react" +import * as AccordionPrimitive from "@radix-ui/react-accordion" +import { ChevronDown } from "lucide-react" + +import { cn } from "@/lib/utils" + +const Accordion = AccordionPrimitive.Root + +const AccordionItem = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AccordionItem.displayName = "AccordionItem" + +const AccordionTrigger = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + + svg]:rotate-180", + className + )} + {...props} + > + {children} + + + +)) +AccordionTrigger.displayName = AccordionPrimitive.Trigger.displayName + +const AccordionContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + +
{children}
+
+)) + +AccordionContent.displayName = AccordionPrimitive.Content.displayName + +export { Accordion, AccordionItem, AccordionTrigger, AccordionContent } diff --git a/chatdesk-ui/components/ui/advanced-settings.tsx b/chatdesk-ui/components/ui/advanced-settings.tsx new file mode 100644 index 0000000..af6a517 --- /dev/null +++ b/chatdesk-ui/components/ui/advanced-settings.tsx @@ -0,0 +1,45 @@ +import { + Collapsible, + CollapsibleContent, + CollapsibleTrigger +} from "@/components/ui/collapsible" +import { IconChevronDown, IconChevronRight } from "@tabler/icons-react" +import { FC, useState } from "react" + +import { useTranslation } from 'react-i18next' + +interface AdvancedSettingsProps { + children: React.ReactNode +} + +export const AdvancedSettings: FC = ({ children }) => { + + const { t } = useTranslation() + + const [isOpen, setIsOpen] = useState( + false + // localStorage.getItem("advanced-settings-open") === "true" + ) + + const handleOpenChange = (isOpen: boolean) => { + setIsOpen(isOpen) + // localStorage.setItem("advanced-settings-open", String(isOpen)) + } + + return ( + + +
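
// The commented-out lines above hint at persisting the disclosure state in localStorage;
// a hydration-safe sketch of that idea (an assumption, not part of this diff) reads the
// stored value lazily and writes it back on change:
const [isOpen, setIsOpen] = useState(
  () =>
    typeof window !== "undefined" &&
    localStorage.getItem("advanced-settings-open") === "true"
)

const handleOpenChange = (open: boolean) => {
  setIsOpen(open)
  localStorage.setItem("advanced-settings-open", String(open))
}
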
+          {t("chat.advancedSettings")}
+          {isOpen ? (
+            <IconChevronDown />
+          ) : (
+            <IconChevronRight />
+          )}
+
+        </CollapsibleTrigger>
+
+        <CollapsibleContent>{children}</CollapsibleContent>
+ ) +} diff --git a/chatdesk-ui/components/ui/alert-dialog.tsx b/chatdesk-ui/components/ui/alert-dialog.tsx new file mode 100644 index 0000000..d468c39 --- /dev/null +++ b/chatdesk-ui/components/ui/alert-dialog.tsx @@ -0,0 +1,141 @@ +"use client" + +import * as React from "react" +import * as AlertDialogPrimitive from "@radix-ui/react-alert-dialog" + +import { cn } from "@/lib/utils" +import { buttonVariants } from "@/components/ui/button" + +const AlertDialog = AlertDialogPrimitive.Root + +const AlertDialogTrigger = AlertDialogPrimitive.Trigger + +const AlertDialogPortal = AlertDialogPrimitive.Portal + +const AlertDialogOverlay = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogOverlay.displayName = AlertDialogPrimitive.Overlay.displayName + +const AlertDialogContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + + +)) +AlertDialogContent.displayName = AlertDialogPrimitive.Content.displayName + +const AlertDialogHeader = ({ + className, + ...props +}: React.HTMLAttributes) => ( +
+) +AlertDialogHeader.displayName = "AlertDialogHeader" + +const AlertDialogFooter = ({ + className, + ...props +}: React.HTMLAttributes) => ( +
+) +AlertDialogFooter.displayName = "AlertDialogFooter" + +const AlertDialogTitle = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogTitle.displayName = AlertDialogPrimitive.Title.displayName + +const AlertDialogDescription = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogDescription.displayName = + AlertDialogPrimitive.Description.displayName + +const AlertDialogAction = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogAction.displayName = AlertDialogPrimitive.Action.displayName + +const AlertDialogCancel = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AlertDialogCancel.displayName = AlertDialogPrimitive.Cancel.displayName + +export { + AlertDialog, + AlertDialogPortal, + AlertDialogOverlay, + AlertDialogTrigger, + AlertDialogContent, + AlertDialogHeader, + AlertDialogFooter, + AlertDialogTitle, + AlertDialogDescription, + AlertDialogAction, + AlertDialogCancel +} diff --git a/chatdesk-ui/components/ui/alert.tsx b/chatdesk-ui/components/ui/alert.tsx new file mode 100644 index 0000000..588ee66 --- /dev/null +++ b/chatdesk-ui/components/ui/alert.tsx @@ -0,0 +1,59 @@ +import * as React from "react" +import { cva, type VariantProps } from "class-variance-authority" + +import { cn } from "@/lib/utils" + +const alertVariants = cva( + "[&>svg]:text-foreground relative w-full rounded-lg border p-4 [&>svg+div]:translate-y-[-3px] [&>svg]:absolute [&>svg]:left-4 [&>svg]:top-4 [&>svg~*]:pl-7", + { + variants: { + variant: { + default: "bg-background text-foreground", + destructive: + "border-destructive/50 text-destructive dark:border-destructive [&>svg]:text-destructive" + } + }, + defaultVariants: { + variant: "default" + } + } +) + +const Alert = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes & VariantProps +>(({ className, variant, ...props }, ref) => ( +
+)) +Alert.displayName = "Alert" + +const AlertTitle = React.forwardRef< + HTMLParagraphElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)) +AlertTitle.displayName = "AlertTitle" + +const AlertDescription = React.forwardRef< + HTMLParagraphElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)) +AlertDescription.displayName = "AlertDescription" + +export { Alert, AlertTitle, AlertDescription } diff --git a/chatdesk-ui/components/ui/aspect-ratio.tsx b/chatdesk-ui/components/ui/aspect-ratio.tsx new file mode 100644 index 0000000..d6a5226 --- /dev/null +++ b/chatdesk-ui/components/ui/aspect-ratio.tsx @@ -0,0 +1,7 @@ +"use client" + +import * as AspectRatioPrimitive from "@radix-ui/react-aspect-ratio" + +const AspectRatio = AspectRatioPrimitive.Root + +export { AspectRatio } diff --git a/chatdesk-ui/components/ui/avatar.tsx b/chatdesk-ui/components/ui/avatar.tsx new file mode 100644 index 0000000..1cf1283 --- /dev/null +++ b/chatdesk-ui/components/ui/avatar.tsx @@ -0,0 +1,50 @@ +"use client" + +import * as React from "react" +import * as AvatarPrimitive from "@radix-ui/react-avatar" + +import { cn } from "@/lib/utils" + +const Avatar = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +Avatar.displayName = AvatarPrimitive.Root.displayName + +const AvatarImage = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AvatarImage.displayName = AvatarPrimitive.Image.displayName + +const AvatarFallback = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +AvatarFallback.displayName = AvatarPrimitive.Fallback.displayName + +export { Avatar, AvatarImage, AvatarFallback } diff --git a/chatdesk-ui/components/ui/badge.tsx b/chatdesk-ui/components/ui/badge.tsx new file mode 100644 index 0000000..56f0ea4 --- /dev/null +++ b/chatdesk-ui/components/ui/badge.tsx @@ -0,0 +1,36 @@ +import * as React from "react" +import { cva, type VariantProps } from "class-variance-authority" + +import { cn } from "@/lib/utils" + +const badgeVariants = cva( + "focus:ring-ring inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold transition-colors focus:outline-none focus:ring-2 focus:ring-offset-2", + { + variants: { + variant: { + default: + "bg-primary text-primary-foreground hover:bg-primary/80 border-transparent", + secondary: + "bg-secondary text-secondary-foreground hover:bg-secondary/80 border-transparent", + destructive: + "bg-destructive text-destructive-foreground hover:bg-destructive/80 border-transparent", + outline: "text-foreground" + } + }, + defaultVariants: { + variant: "default" + } + } +) + +export interface BadgeProps + extends React.HTMLAttributes, + VariantProps {} + +function Badge({ className, variant, ...props }: BadgeProps) { + return ( +
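
// The Badge body above presumably applies badgeVariants through cn and spreads the
// remaining div props; a sketch of that return plus a usage example (the markup is
// inferred from the cva definition, not copied from the original):
function BadgeSketch({ className, variant, ...props }: BadgeProps) {
  return (
    <div className={cn(badgeVariants({ variant }), className)} {...props} />
  )
}

// e.g. <Badge variant="secondary">draft</Badge>
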
+ ) +} + +export { Badge, badgeVariants } diff --git a/chatdesk-ui/components/ui/brand.tsx b/chatdesk-ui/components/ui/brand.tsx new file mode 100644 index 0000000..5af67ef --- /dev/null +++ b/chatdesk-ui/components/ui/brand.tsx @@ -0,0 +1,29 @@ +"use client" + +import Link from "next/link" +import { FC } from "react" +import { ChatbotUISVG } from "../icons/chatbotui-svg" + +import { useTranslation } from 'react-i18next' + +interface BrandProps { + theme?: "dark" | "light" +} + +export const Brand: FC = ({ theme = "dark" }) => { + const { t } = useTranslation() + return ( + +
+ +
+ +
{t("Company Name")}
+ + ) +} diff --git a/chatdesk-ui/components/ui/button.tsx b/chatdesk-ui/components/ui/button.tsx new file mode 100644 index 0000000..fb4fc2f --- /dev/null +++ b/chatdesk-ui/components/ui/button.tsx @@ -0,0 +1,56 @@ +import { Slot } from "@radix-ui/react-slot" +import { cva, type VariantProps } from "class-variance-authority" +import * as React from "react" + +import { cn } from "@/lib/utils" + +const buttonVariants = cva( + "ring-offset-background focus-visible:ring-ring inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium transition-colors hover:opacity-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50", + { + variants: { + variant: { + default: "bg-primary text-primary-foreground hover:bg-primary/90", + destructive: + "bg-destructive text-destructive-foreground hover:bg-destructive/90", + outline: + "border-input bg-background hover:bg-accent hover:text-accent-foreground border", + secondary: + "bg-secondary text-secondary-foreground hover:bg-secondary/80", + ghost: "hover:bg-accent hover:text-accent-foreground", + link: "text-primary underline-offset-4 hover:underline" + }, + size: { + default: "h-10 px-4 py-2", + sm: "h-9 rounded-md px-3", + lg: "h-11 rounded-md px-8", + icon: "size-10" + } + }, + defaultVariants: { + variant: "default", + size: "default" + } + } +) + +export interface ButtonProps + extends React.ButtonHTMLAttributes, + VariantProps { + asChild?: boolean +} + +const Button = React.forwardRef( + ({ className, variant, size, asChild = false, ...props }, ref) => { + const Comp = asChild ? Slot : "button" + return ( + + ) + } +) +Button.displayName = "Button" + +export { Button, buttonVariants } diff --git a/chatdesk-ui/components/ui/calendar.tsx b/chatdesk-ui/components/ui/calendar.tsx new file mode 100644 index 0000000..07c8aa4 --- /dev/null +++ b/chatdesk-ui/components/ui/calendar.tsx @@ -0,0 +1,66 @@ +"use client" + +import * as React from "react" +import { ChevronLeft, ChevronRight } from "lucide-react" +import { DayPicker } from "react-day-picker" + +import { cn } from "@/lib/utils" +import { buttonVariants } from "@/components/ui/button" + +export type CalendarProps = React.ComponentProps + +function Calendar({ + className, + classNames, + showOutsideDays = true, + ...props +}: CalendarProps) { + return ( + , + IconRight: ({ ...props }) => + }} + {...props} + /> + ) +} +Calendar.displayName = "Calendar" + +export { Calendar } diff --git a/chatdesk-ui/components/ui/card.tsx b/chatdesk-ui/components/ui/card.tsx new file mode 100644 index 0000000..a26fd5d --- /dev/null +++ b/chatdesk-ui/components/ui/card.tsx @@ -0,0 +1,79 @@ +import * as React from "react" + +import { cn } from "@/lib/utils" + +const Card = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
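
// Usage sketch for the Button above: variant/size pick a buttonVariants class set, and
// asChild renders those classes onto a child element via Radix Slot. The Link target is
// illustrative only.
import Link from "next/link"
import { Button } from "@/components/ui/button"

export const ButtonUsageSketch = () => (
  <>
    <Button variant="outline" size="sm">
      Cancel
    </Button>
    <Button asChild variant="link">
      <Link href="/settings">Settings</Link>
    </Button>
  </>
)
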
+)) +Card.displayName = "Card" + +const CardHeader = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)) +CardHeader.displayName = "CardHeader" + +const CardTitle = React.forwardRef< + HTMLParagraphElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +

+)) +CardTitle.displayName = "CardTitle" + +const CardDescription = React.forwardRef< + HTMLParagraphElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +

+)) +CardDescription.displayName = "CardDescription" + +const CardContent = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +

+)) +CardContent.displayName = "CardContent" + +const CardFooter = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => ( +
+)) +CardFooter.displayName = "CardFooter" + +export { Card, CardHeader, CardFooter, CardTitle, CardDescription, CardContent } diff --git a/chatdesk-ui/components/ui/chat-settings-form.tsx b/chatdesk-ui/components/ui/chat-settings-form.tsx new file mode 100644 index 0000000..1aeb331 --- /dev/null +++ b/chatdesk-ui/components/ui/chat-settings-form.tsx @@ -0,0 +1,261 @@ +"use client" + +import { ChatbotUIContext } from "@/context/context" +import { CHAT_SETTING_LIMITS } from "@/lib/chat-setting-limits" +import { ChatSettings } from "@/types" +import { IconInfoCircle } from "@tabler/icons-react" +import { FC, useContext } from "react" +import { ModelSelect } from "../models/model-select" +import { AdvancedSettings } from "./advanced-settings" +import { Checkbox } from "./checkbox" +import { Label } from "./label" +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue +} from "./select" +import { Slider } from "./slider" +import { TextareaAutosize } from "./textarea-autosize" +import { WithTooltip } from "./with-tooltip" + +import { useTranslation } from 'react-i18next' + +interface ChatSettingsFormProps { + chatSettings: ChatSettings + onChangeChatSettings: (value: ChatSettings) => void + useAdvancedDropdown?: boolean + showTooltip?: boolean +} + +export const ChatSettingsForm: FC = ({ + chatSettings, + onChangeChatSettings, + useAdvancedDropdown = true, + showTooltip = true +}) => { + + const { t } = useTranslation() + + const { profile, models } = useContext(ChatbotUIContext) + + if (!profile) return null + + return ( +
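
// ChatSettings comes from "@/types"; judging from the fields the controls below read and
// write, it presumably looks roughly like this (a sketch, not the actual type):
interface ChatSettingsSketch {
  model: string // LLM id, also the key into CHAT_SETTING_LIMITS
  prompt: string
  temperature: number
  contextLength: number
  includeProfileContext: boolean
  includeWorkspaceInstructions: boolean
}
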
+
+ + + { + onChangeChatSettings({ ...chatSettings, model }) + }} + /> +
+ +
+ + + { + onChangeChatSettings({ ...chatSettings, prompt }) + }} + value={chatSettings.prompt} + minRows={3} + maxRows={6} + /> +
+ + {useAdvancedDropdown ? ( + + + + ) : ( +
+ +
+ )} +
+ ) +} + +interface AdvancedContentProps { + chatSettings: ChatSettings + onChangeChatSettings: (value: ChatSettings) => void + showTooltip: boolean +} + +const AdvancedContent: FC = ({ + chatSettings, + onChangeChatSettings, + showTooltip +}) => { + + const { t } = useTranslation() + + const { profile, selectedWorkspace, availableOpenRouterModels, models } = + useContext(ChatbotUIContext) + + const isCustomModel = models.some( + model => model.model_id === chatSettings.model + ) + + function findOpenRouterModel(modelId: string) { + return availableOpenRouterModels.find(model => model.modelId === modelId) + } + + const MODEL_LIMITS = CHAT_SETTING_LIMITS[chatSettings.model] || { + MIN_TEMPERATURE: 0, + MAX_TEMPERATURE: 1, + MAX_CONTEXT_LENGTH: + findOpenRouterModel(chatSettings.model)?.maxContext || 4096 + } + + return ( +
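
// CHAT_SETTING_LIMITS (from "@/lib/chat-setting-limits") is keyed by model id and must
// expose at least the three bounds read here; a sketch of its shape with an illustrative
// entry (the numbers are made up, not taken from the real table):
const CHAT_SETTING_LIMITS_SKETCH: Record<
  string,
  { MIN_TEMPERATURE: number; MAX_TEMPERATURE: number; MAX_CONTEXT_LENGTH: number }
> = {
  "gpt-4o": { MIN_TEMPERATURE: 0, MAX_TEMPERATURE: 2, MAX_CONTEXT_LENGTH: 128000 }
}
// Models missing from the table fall back to a 0 to 1 temperature range and to the
// OpenRouter model's maxContext (or 4096), as in the fallback object above.
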
+
+ + + { + onChangeChatSettings({ + ...chatSettings, + temperature: temperature[0] + }) + }} + min={MODEL_LIMITS.MIN_TEMPERATURE} + max={MODEL_LIMITS.MAX_TEMPERATURE} + step={0.01} + /> +
+ +
+ + + { + onChangeChatSettings({ + ...chatSettings, + contextLength: contextLength[0] + }) + }} + min={0} + max={ + isCustomModel + ? models.find(model => model.model_id === chatSettings.model) + ?.context_length + : MODEL_LIMITS.MAX_CONTEXT_LENGTH + } + step={1} + /> +
+ +
+ + onChangeChatSettings({ + ...chatSettings, + includeProfileContext: value + }) + } + /> + + + + {showTooltip && ( + + {profile?.profile_context || t("chat.noProfileContext")} +
+ } + trigger={ + + } + /> + )} +
+ +
+ + onChangeChatSettings({ + ...chatSettings, + includeWorkspaceInstructions: value + }) + } + /> + + + + {showTooltip && ( + + {selectedWorkspace?.instructions || + "No workspace instructions."} +
+ } + trigger={ + + } + /> + )} +
+ +
+ + + +
+
+ ) +} diff --git a/chatdesk-ui/components/ui/checkbox.tsx b/chatdesk-ui/components/ui/checkbox.tsx new file mode 100644 index 0000000..6abd7f8 --- /dev/null +++ b/chatdesk-ui/components/ui/checkbox.tsx @@ -0,0 +1,30 @@ +"use client" + +import * as React from "react" +import * as CheckboxPrimitive from "@radix-ui/react-checkbox" +import { Check } from "lucide-react" + +import { cn } from "@/lib/utils" + +const Checkbox = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + + + +)) +Checkbox.displayName = CheckboxPrimitive.Root.displayName + +export { Checkbox } diff --git a/chatdesk-ui/components/ui/collapsible.tsx b/chatdesk-ui/components/ui/collapsible.tsx new file mode 100644 index 0000000..9fa4894 --- /dev/null +++ b/chatdesk-ui/components/ui/collapsible.tsx @@ -0,0 +1,11 @@ +"use client" + +import * as CollapsiblePrimitive from "@radix-ui/react-collapsible" + +const Collapsible = CollapsiblePrimitive.Root + +const CollapsibleTrigger = CollapsiblePrimitive.CollapsibleTrigger + +const CollapsibleContent = CollapsiblePrimitive.CollapsibleContent + +export { Collapsible, CollapsibleTrigger, CollapsibleContent } diff --git a/chatdesk-ui/components/ui/command.tsx b/chatdesk-ui/components/ui/command.tsx new file mode 100644 index 0000000..bf47537 --- /dev/null +++ b/chatdesk-ui/components/ui/command.tsx @@ -0,0 +1,155 @@ +"use client" + +import * as React from "react" +import { type DialogProps } from "@radix-ui/react-dialog" +import { Command as CommandPrimitive } from "cmdk" +import { Search } from "lucide-react" + +import { cn } from "@/lib/utils" +import { Dialog, DialogContent } from "@/components/ui/dialog" + +const Command = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +Command.displayName = CommandPrimitive.displayName + +interface CommandDialogProps extends DialogProps {} + +const CommandDialog = ({ children, ...props }: CommandDialogProps) => { + return ( + + + + {children} + + + + ) +} + +const CommandInput = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( +
+ + +
+)) + +CommandInput.displayName = CommandPrimitive.Input.displayName + +const CommandList = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) + +CommandList.displayName = CommandPrimitive.List.displayName + +const CommandEmpty = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>((props, ref) => ( + +)) + +CommandEmpty.displayName = CommandPrimitive.Empty.displayName + +const CommandGroup = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) + +CommandGroup.displayName = CommandPrimitive.Group.displayName + +const CommandSeparator = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +CommandSeparator.displayName = CommandPrimitive.Separator.displayName + +const CommandItem = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) + +CommandItem.displayName = CommandPrimitive.Item.displayName + +const CommandShortcut = ({ + className, + ...props +}: React.HTMLAttributes) => { + return ( + + ) +} +CommandShortcut.displayName = "CommandShortcut" + +export { + Command, + CommandDialog, + CommandInput, + CommandList, + CommandEmpty, + CommandGroup, + CommandItem, + CommandShortcut, + CommandSeparator +} diff --git a/chatdesk-ui/components/ui/context-menu.tsx b/chatdesk-ui/components/ui/context-menu.tsx new file mode 100644 index 0000000..05b7f09 --- /dev/null +++ b/chatdesk-ui/components/ui/context-menu.tsx @@ -0,0 +1,200 @@ +"use client" + +import * as React from "react" +import * as ContextMenuPrimitive from "@radix-ui/react-context-menu" +import { Check, ChevronRight, Circle } from "lucide-react" + +import { cn } from "@/lib/utils" + +const ContextMenu = ContextMenuPrimitive.Root + +const ContextMenuTrigger = ContextMenuPrimitive.Trigger + +const ContextMenuGroup = ContextMenuPrimitive.Group + +const ContextMenuPortal = ContextMenuPrimitive.Portal + +const ContextMenuSub = ContextMenuPrimitive.Sub + +const ContextMenuRadioGroup = ContextMenuPrimitive.RadioGroup + +const ContextMenuSubTrigger = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef & { + inset?: boolean + } +>(({ className, inset, children, ...props }, ref) => ( + + {children} + + +)) +ContextMenuSubTrigger.displayName = ContextMenuPrimitive.SubTrigger.displayName + +const ContextMenuSubContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +ContextMenuSubContent.displayName = ContextMenuPrimitive.SubContent.displayName + +const ContextMenuContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + +)) +ContextMenuContent.displayName = ContextMenuPrimitive.Content.displayName + +const ContextMenuItem = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef & { + inset?: boolean + } +>(({ className, inset, ...props }, ref) => ( + +)) +ContextMenuItem.displayName = ContextMenuPrimitive.Item.displayName + +const ContextMenuCheckboxItem = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, checked, ...props }, ref) => ( + + + + + + + {children} + +)) +ContextMenuCheckboxItem.displayName = + ContextMenuPrimitive.CheckboxItem.displayName + +const ContextMenuRadioItem = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ 
className, children, ...props }, ref) => ( + + + + + + + {children} + +)) +ContextMenuRadioItem.displayName = ContextMenuPrimitive.RadioItem.displayName + +const ContextMenuLabel = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef & { + inset?: boolean + } +>(({ className, inset, ...props }, ref) => ( + +)) +ContextMenuLabel.displayName = ContextMenuPrimitive.Label.displayName + +const ContextMenuSeparator = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +ContextMenuSeparator.displayName = ContextMenuPrimitive.Separator.displayName + +const ContextMenuShortcut = ({ + className, + ...props +}: React.HTMLAttributes) => { + return ( + + ) +} +ContextMenuShortcut.displayName = "ContextMenuShortcut" + +export { + ContextMenu, + ContextMenuTrigger, + ContextMenuContent, + ContextMenuItem, + ContextMenuCheckboxItem, + ContextMenuRadioItem, + ContextMenuLabel, + ContextMenuSeparator, + ContextMenuShortcut, + ContextMenuGroup, + ContextMenuPortal, + ContextMenuSub, + ContextMenuSubContent, + ContextMenuSubTrigger, + ContextMenuRadioGroup +} diff --git a/chatdesk-ui/components/ui/dashboard.tsx b/chatdesk-ui/components/ui/dashboard.tsx new file mode 100644 index 0000000..1b27ff2 --- /dev/null +++ b/chatdesk-ui/components/ui/dashboard.tsx @@ -0,0 +1,138 @@ +"use client" + +import { Sidebar } from "@/components/sidebar/sidebar" +import { SidebarSwitcher } from "@/components/sidebar/sidebar-switcher" +import { Button } from "@/components/ui/button" +import { Tabs } from "@/components/ui/tabs" +import useHotkey from "@/lib/hooks/use-hotkey" +import { cn } from "@/lib/utils" +import { ContentType } from "@/types" +import { IconChevronCompactRight } from "@tabler/icons-react" +import { usePathname, useRouter, useSearchParams } from "next/navigation" +import { FC, useState } from "react" +import { useSelectFileHandler } from "../chat/chat-hooks/use-select-file-handler" +import { CommandK } from "../utility/command-k" + +import { useTranslation } from 'react-i18next' + +export const SIDEBAR_WIDTH = 350 + +interface DashboardProps { + children: React.ReactNode +} + +export const Dashboard: FC = ({ children }) => { + + const { t } = useTranslation() + + useHotkey("s", () => setShowSidebar(prevState => !prevState)) + + const pathname = usePathname() + const router = useRouter() + const searchParams = useSearchParams() + const tabValue = searchParams.get("tab") || "chats" + + const { handleSelectDeviceFile } = useSelectFileHandler() + + const [contentType, setContentType] = useState( + tabValue as ContentType + ) + const [showSidebar, setShowSidebar] = useState( + localStorage.getItem("showSidebar") === "true" + ) + const [isDragging, setIsDragging] = useState(false) + + const onFileDrop = (event: React.DragEvent) => { + event.preventDefault() + + const files = event.dataTransfer.files + const file = files[0] + + handleSelectDeviceFile(file) + + setIsDragging(false) + } + + const handleDragEnter = (event: React.DragEvent) => { + event.preventDefault() + setIsDragging(true) + } + + const handleDragLeave = (event: React.DragEvent) => { + event.preventDefault() + setIsDragging(false) + } + + const onDragOver = (event: React.DragEvent) => { + event.preventDefault() + } + + const handleToggleSidebar = () => { + setShowSidebar(prevState => !prevState) + localStorage.setItem("showSidebar", String(!showSidebar)) + } + + return ( +
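
// The drag handlers above are presumably attached to the main content wrapper rendered
// below, alongside the sidebar toggle button; a sketch of that wiring (classNames and
// the exact nesting are assumptions):
<div
  className="flex grow flex-col"
  onDrop={onFileDrop}
  onDragOver={onDragOver}
  onDragEnter={handleDragEnter}
  onDragLeave={handleDragLeave}
>
  {isDragging ? <div>{t("side.dropFileHere")}</div> : children}

  <Button variant="ghost" size="icon" onClick={handleToggleSidebar}>
    <IconChevronCompactRight />
  </Button>
</div>
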
+ + +
+ {showSidebar && ( + { + setContentType(tabValue as ContentType) + router.replace(`${pathname}?tab=${tabValue}`) + }} + > + + + + + )} +
+ +
+ {isDragging ? ( +
+ {t("side.dropFileHere")} +
+ ) : ( + children + )} + + +
+
+ ) +} diff --git a/chatdesk-ui/components/ui/dialog.tsx b/chatdesk-ui/components/ui/dialog.tsx new file mode 100644 index 0000000..5c01896 --- /dev/null +++ b/chatdesk-ui/components/ui/dialog.tsx @@ -0,0 +1,121 @@ +"use client" + +import * as DialogPrimitive from "@radix-ui/react-dialog" +import * as React from "react" + +import { cn } from "@/lib/utils" + +const Dialog = DialogPrimitive.Root + +const DialogTrigger = DialogPrimitive.Trigger + +const DialogPortal = DialogPrimitive.Portal + +const DialogClose = DialogPrimitive.Close + +const DialogOverlay = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +DialogOverlay.displayName = DialogPrimitive.Overlay.displayName + +const DialogContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + + + + {children} + {/* + + Close + */} + + +)) +DialogContent.displayName = DialogPrimitive.Content.displayName + +const DialogHeader = ({ + className, + ...props +}: React.HTMLAttributes) => ( +
+) +DialogHeader.displayName = "DialogHeader" + +const DialogFooter = ({ + className, + ...props +}: React.HTMLAttributes) => ( +
+) +DialogFooter.displayName = "DialogFooter" + +const DialogTitle = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +DialogTitle.displayName = DialogPrimitive.Title.displayName + +const DialogDescription = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +DialogDescription.displayName = DialogPrimitive.Description.displayName + +export { + Dialog, + DialogClose, + DialogContent, + DialogDescription, + DialogFooter, + DialogHeader, + DialogOverlay, + DialogPortal, + DialogTitle, + DialogTrigger +} diff --git a/chatdesk-ui/components/ui/dropdown-menu.tsx b/chatdesk-ui/components/ui/dropdown-menu.tsx new file mode 100644 index 0000000..0519d65 --- /dev/null +++ b/chatdesk-ui/components/ui/dropdown-menu.tsx @@ -0,0 +1,200 @@ +"use client" + +import * as React from "react" +import * as DropdownMenuPrimitive from "@radix-ui/react-dropdown-menu" +import { Check, ChevronRight, Circle } from "lucide-react" + +import { cn } from "@/lib/utils" + +const DropdownMenu = DropdownMenuPrimitive.Root + +const DropdownMenuTrigger = DropdownMenuPrimitive.Trigger + +const DropdownMenuGroup = DropdownMenuPrimitive.Group + +const DropdownMenuPortal = DropdownMenuPrimitive.Portal + +const DropdownMenuSub = DropdownMenuPrimitive.Sub + +const DropdownMenuRadioGroup = DropdownMenuPrimitive.RadioGroup + +const DropdownMenuSubTrigger = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef & { + inset?: boolean + } +>(({ className, inset, children, ...props }, ref) => ( + + {children} + + +)) +DropdownMenuSubTrigger.displayName = + DropdownMenuPrimitive.SubTrigger.displayName + +const DropdownMenuSubContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +DropdownMenuSubContent.displayName = + DropdownMenuPrimitive.SubContent.displayName + +const DropdownMenuContent = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, sideOffset = 4, ...props }, ref) => ( + + + +)) +DropdownMenuContent.displayName = DropdownMenuPrimitive.Content.displayName + +const DropdownMenuItem = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef & { + inset?: boolean + } +>(({ className, inset, ...props }, ref) => ( + +)) +DropdownMenuItem.displayName = DropdownMenuPrimitive.Item.displayName + +const DropdownMenuCheckboxItem = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, checked, ...props }, ref) => ( + + + + + + + {children} + +)) +DropdownMenuCheckboxItem.displayName = + DropdownMenuPrimitive.CheckboxItem.displayName + +const DropdownMenuRadioItem = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, children, ...props }, ref) => ( + + + + + + + {children} + +)) +DropdownMenuRadioItem.displayName = DropdownMenuPrimitive.RadioItem.displayName + +const DropdownMenuLabel = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef & { + inset?: boolean + } +>(({ className, inset, ...props }, ref) => ( + +)) +DropdownMenuLabel.displayName = DropdownMenuPrimitive.Label.displayName + +const DropdownMenuSeparator = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + +)) +DropdownMenuSeparator.displayName = DropdownMenuPrimitive.Separator.displayName + +const DropdownMenuShortcut = ({ + className, + ...props +}: 
React.HTMLAttributes) => { + return ( + + ) +} +DropdownMenuShortcut.displayName = "DropdownMenuShortcut" + +export { + DropdownMenu, + DropdownMenuTrigger, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuCheckboxItem, + DropdownMenuRadioItem, + DropdownMenuLabel, + DropdownMenuSeparator, + DropdownMenuShortcut, + DropdownMenuGroup, + DropdownMenuPortal, + DropdownMenuSub, + DropdownMenuSubContent, + DropdownMenuSubTrigger, + DropdownMenuRadioGroup +} diff --git a/chatdesk-ui/components/ui/file-icon.tsx b/chatdesk-ui/components/ui/file-icon.tsx new file mode 100644 index 0000000..814dda3 --- /dev/null +++ b/chatdesk-ui/components/ui/file-icon.tsx @@ -0,0 +1,36 @@ +import { + IconFile, + IconFileText, + IconFileTypeCsv, + IconFileTypeDocx, + IconFileTypePdf, + IconJson, + IconMarkdown, + IconPhoto +} from "@tabler/icons-react" +import { FC } from "react" + +interface FileIconProps { + type: string + size?: number +} + +export const FileIcon: FC = ({ type, size = 32 }) => { + if (type.includes("image")) { + return + } else if (type.includes("pdf")) { + return + } else if (type.includes("csv")) { + return + } else if (type.includes("docx")) { + return + } else if (type.includes("plain")) { + return + } else if (type.includes("json")) { + return + } else if (type.includes("markdown")) { + return + } else { + return + } +} diff --git a/chatdesk-ui/components/ui/file-preview.tsx b/chatdesk-ui/components/ui/file-preview.tsx new file mode 100644 index 0000000..2bfbc70 --- /dev/null +++ b/chatdesk-ui/components/ui/file-preview.tsx @@ -0,0 +1,68 @@ +import { cn } from "@/lib/utils" +import { Tables } from "@/supabase/types" +import { ChatFile, MessageImage } from "@/types" +import { IconFileFilled } from "@tabler/icons-react" +import Image from "next/image" +import { FC } from "react" +import { DrawingCanvas } from "../utility/drawing-canvas" +import { Dialog, DialogContent } from "./dialog" + +interface FilePreviewProps { + type: "image" | "file" | "file_item" + item: ChatFile | MessageImage | Tables<"file_items"> + isOpen: boolean + onOpenChange: (isOpen: boolean) => void +} + +export const FilePreview: FC = ({ + type, + item, + isOpen, + onOpenChange +}) => { + return ( + + + {(() => { + if (type === "image") { + const imageItem = item as MessageImage + + return imageItem.file ? ( + + ) : ( + File image + ) + } else if (type === "file_item") { + const fileItem = item as Tables<"file_items"> + return ( +
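
// A sketch of what the file_item and "file" branches presumably render: the raw chunk
// text in a scrollable container, and a plain IconFileFilled badge (classNames are
// assumptions, not the original markup):
if (type === "file_item") {
  const fileItem = item as Tables<"file_items">
  return (
    <div className="bg-background max-h-[67vh] overflow-auto whitespace-pre-wrap rounded p-4">
      {fileItem.content}
    </div>
  )
}
if (type === "file") {
  return (
    <div className="rounded bg-blue-500 p-2">
      <IconFileFilled />
    </div>
  )
}
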
+
{fileItem.content}
+
+ ) + } else if (type === "file") { + return ( +
+ +
+ ) + } + })()} +
+
+ ) +} diff --git a/chatdesk-ui/components/ui/form.tsx b/chatdesk-ui/components/ui/form.tsx new file mode 100644 index 0000000..38cb190 --- /dev/null +++ b/chatdesk-ui/components/ui/form.tsx @@ -0,0 +1,176 @@ +import * as React from "react" +import * as LabelPrimitive from "@radix-ui/react-label" +import { Slot } from "@radix-ui/react-slot" +import { + Controller, + ControllerProps, + FieldPath, + FieldValues, + FormProvider, + useFormContext +} from "react-hook-form" + +import { cn } from "@/lib/utils" +import { Label } from "@/components/ui/label" + +const Form = FormProvider + +type FormFieldContextValue< + TFieldValues extends FieldValues = FieldValues, + TName extends FieldPath = FieldPath +> = { + name: TName +} + +const FormFieldContext = React.createContext( + {} as FormFieldContextValue +) + +const FormField = < + TFieldValues extends FieldValues = FieldValues, + TName extends FieldPath = FieldPath +>({ + ...props +}: ControllerProps) => { + return ( + + + + ) +} + +const useFormField = () => { + const fieldContext = React.useContext(FormFieldContext) + const itemContext = React.useContext(FormItemContext) + const { getFieldState, formState } = useFormContext() + + const fieldState = getFieldState(fieldContext.name, formState) + + if (!fieldContext) { + throw new Error("useFormField should be used within ") + } + + const { id } = itemContext + + return { + id, + name: fieldContext.name, + formItemId: `${id}-form-item`, + formDescriptionId: `${id}-form-item-description`, + formMessageId: `${id}-form-item-message`, + ...fieldState + } +} + +type FormItemContextValue = { + id: string +} + +const FormItemContext = React.createContext( + {} as FormItemContextValue +) + +const FormItem = React.forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => { + const id = React.useId() + + return ( + +
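
// FormItem above presumably provides the generated id through FormItemContext and
// renders a plain div; a sketch of that return (the className is an assumption):
return (
  <FormItemContext.Provider value={{ id }}>
    <div ref={ref} className={cn("space-y-2", className)} {...props} />
  </FormItemContext.Provider>
)
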
+ + ) +}) +FormItem.displayName = "FormItem" + +const FormLabel = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => { + const { error, formItemId } = useFormField() + + return ( +