chatdesk-ui/auth_v2.169.0/internal/models/sso.go

263 lines
6.5 KiB
Go

package models
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"reflect"
"strings"
"time"
"github.com/crewjam/saml"
"github.com/crewjam/saml/samlsp"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/supabase/auth/internal/storage"
)
// SSOProvider is the model for a row in the sso_providers table, together
// with the SAML provider and SSO domains associated with it.
type SSOProvider struct {
ID uuid.UUID `db:"id" json:"id"`
// SAMLProvider is the has-one association into saml_providers, linked by
// sso_provider_id. It is serialized under the "saml" JSON key.
SAMLProvider SAMLProvider `has_one:"saml_providers" fk_id:"sso_provider_id" json:"saml,omitempty"`
// SSODomains is the has-many association into sso_domains, linked by
// sso_provider_id. It is serialized under the "domains" JSON key.
SSODomains []SSODomain `has_many:"sso_domains" fk_id:"sso_provider_id" json:"domains"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
}
// TableName returns the name of the database table backing SSOProvider.
func (SSOProvider) TableName() string {
	const table = "sso_providers"
	return table
}
// Type returns the kind of SSO provider. It is always "saml", regardless of
// the receiver's contents.
func (SSOProvider) Type() string {
	const providerType = "saml"
	return providerType
}
// SAMLAttribute is a single entry of a SAMLAttributeMapping. Every field is
// omitted from the JSON encoding when it holds its zero value.
//
// NOTE(review): the field semantics below are inferred from the names only;
// confirm against the code that applies the mapping.
type SAMLAttribute struct {
// Name is presumably a single SAML attribute name to read.
Name string `json:"name,omitempty"`
// Names is presumably a list of alternative SAML attribute names.
Names []string `json:"names,omitempty"`
// Default is presumably the fallback value; it may hold any JSON value.
Default interface{} `json:"default,omitempty"`
// Array presumably marks the attribute as multi-valued.
Array bool `json:"array,omitempty"`
}
// SAMLAttributeMapping is a set of SAMLAttribute entries indexed by key. It
// is stored in the database as a JSON document through its Scan and Value
// methods.
type SAMLAttributeMapping struct {
// Keys maps each key to its SAMLAttribute; it is omitted from JSON when empty.
Keys map[string]SAMLAttribute `json:"keys,omitempty"`
}
// Equal reports whether m and o describe the same attribute mapping.
//
// Two nil pointers (or the same pointer) are equal; a nil and a non-nil
// pointer are not. A nil Keys map is treated the same as an empty one.
// Entries are compared field by field: Name and Array directly, Names
// element-wise in order (nil and empty slices are equal), and Default with
// reflect.DeepEqual.
func (m *SAMLAttributeMapping) Equal(o *SAMLAttributeMapping) bool {
	switch {
	case m == o:
		return true
	case m == nil || o == nil:
		return false
	case len(m.Keys) != len(o.Keys):
		// Also covers nil-vs-empty: both have length zero and fall through
		// to a loop that never runs.
		return false
	}

	for key, lhs := range m.Keys {
		rhs, found := o.Keys[key]
		if !found {
			return false
		}

		if lhs.Name != rhs.Name || lhs.Array != rhs.Array {
			return false
		}

		if len(lhs.Names) != len(rhs.Names) {
			return false
		}
		for i, name := range lhs.Names {
			if rhs.Names[i] != name {
				return false
			}
		}

		if !reflect.DeepEqual(lhs.Default, rhs.Default) {
			return false
		}
	}

	return true
}
// Scan implements sql.Scanner by decoding a JSON document into m.
//
// The source may be a []byte or a string. Value encodes the mapping as a
// string, and some drivers hand textual JSON columns back as string rather
// than []byte, so both forms are accepted to keep Scan and Value symmetric.
// Any other source type (including nil) is rejected, and malformed JSON
// yields the decoder's error.
func (m *SAMLAttributeMapping) Scan(src interface{}) error {
	var raw []byte

	switch v := src.(type) {
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		// Message kept as-is so callers matching on it are unaffected.
		return errors.New("scan source was not []byte")
	}

	return json.Unmarshal(raw, m)
}
// Value implements driver.Valuer by encoding the mapping as JSON. The
// encoded document is returned as a string; a marshalling failure is
// returned unchanged with a nil value.
func (m SAMLAttributeMapping) Value() (driver.Value, error) {
	encoded, err := json.Marshal(m)
	if err == nil {
		return string(encoded), nil
	}

	return nil, err
}
// SAMLProvider is the model for a row in the saml_providers table. Each row
// belongs to one SSOProvider, referenced by SSOProviderID.
type SAMLProvider struct {
ID uuid.UUID `db:"id" json:"-"`
// SSOProvider is the belongs-to association into sso_providers; it is
// excluded from JSON output.
SSOProvider *SSOProvider `belongs_to:"sso_providers" json:"-"`
SSOProviderID uuid.UUID `db:"sso_provider_id" json:"-"`
EntityID string `db:"entity_id" json:"entity_id"`
// MetadataXML is the metadata document that EntityDescriptor parses.
MetadataXML string `db:"metadata_xml" json:"metadata_xml,omitempty"`
// MetadataURL is optional (nil when unset).
MetadataURL *string `db:"metadata_url" json:"metadata_url,omitempty"`
// AttributeMapping is persisted as JSON via SAMLAttributeMapping's Scan and Value.
AttributeMapping SAMLAttributeMapping `db:"attribute_mapping" json:"attribute_mapping,omitempty"`
// NameIDFormat is optional (nil when unset).
NameIDFormat *string `db:"name_id_format" json:"name_id_format,omitempty"`
CreatedAt time.Time `db:"created_at" json:"-"`
UpdatedAt time.Time `db:"updated_at" json:"-"`
}
// TableName returns the name of the database table backing SAMLProvider.
func (SAMLProvider) TableName() string {
	const table = "saml_providers"
	return table
}
// EntityDescriptor parses the provider's stored MetadataXML with
// samlsp.ParseMetadata and returns the resulting entity descriptor, or the
// parser's error.
func (p SAMLProvider) EntityDescriptor() (*saml.EntityDescriptor, error) {
	rawMetadata := []byte(p.MetadataXML)

	return samlsp.ParseMetadata(rawMetadata)
}
// SSODomain is the model for a row in the sso_domains table. Each row ties
// one domain to the SSOProvider referenced by SSOProviderID. Only Domain is
// included in JSON output.
type SSODomain struct {
ID uuid.UUID `db:"id" json:"-"`
// SSOProvider is the belongs-to association into sso_providers.
SSOProvider *SSOProvider `belongs_to:"sso_providers" json:"-"`
SSOProviderID uuid.UUID `db:"sso_provider_id" json:"-"`
// Domain is matched exactly by FindSSOProviderByDomain.
Domain string `db:"domain" json:"domain"`
CreatedAt time.Time `db:"created_at" json:"-"`
UpdatedAt time.Time `db:"updated_at" json:"-"`
}
// TableName returns the name of the database table backing SSODomain.
func (SSODomain) TableName() string {
	const table = "sso_domains"
	return table
}
// SAMLRelayState is the model for a row in the saml_relay_states table.
//
// NOTE(review): presumably one row is created per SAML authentication
// request and looked up again when the response arrives; confirm against
// the callers of FindSAMLRelayStateByID.
type SAMLRelayState struct {
ID uuid.UUID `db:"id"`
SSOProviderID uuid.UUID `db:"sso_provider_id"`
RequestID string `db:"request_id"`
// ForEmail is optional (nil when unset).
ForEmail *string `db:"for_email"`
RedirectTo string `db:"redirect_to"`
CreatedAt time.Time `db:"created_at" json:"-"`
UpdatedAt time.Time `db:"updated_at" json:"-"`
// FlowStateID is optional (nil when unset).
FlowStateID *uuid.UUID `db:"flow_state_id" json:"flow_state_id,omitempty"`
// FlowState is the belongs-to association into flow_state; it is not a
// column of this table (db:"-").
FlowState *FlowState `db:"-" json:"flow_state,omitempty" belongs_to:"flow_state"`
}
// TableName returns the name of the database table backing SAMLRelayState.
func (SAMLRelayState) TableName() string {
	const table = "saml_relay_states"
	return table
}
// FindSAMLProviderByEntityID looks up the SSO provider whose SAML provider
// row has the given entity ID.
//
// It first finds the saml_providers row by entity_id, then loads the owning
// SSOProvider by ID through tx.Eager(). If either lookup matches no row,
// SSOProviderNotFoundError is returned; any other failure is returned
// wrapped with context.
func FindSAMLProviderByEntityID(tx *storage.Connection, entityId string) (*SSOProvider, error) {
	var samlProvider SAMLProvider
	if err := tx.Q().Where("entity_id = ?", entityId).First(&samlProvider); err != nil {
		if errors.Cause(err) == sql.ErrNoRows {
			return nil, SSOProviderNotFoundError{}
		}

		return nil, errors.Wrap(err, "error finding SAML SSO provider by EntityID")
	}

	var ssoProvider SSOProvider
	if err := tx.Eager().Q().Where("id = ?", samlProvider.SSOProviderID).First(&ssoProvider); err != nil {
		// Report a missing parent row as not-found, matching the second
		// lookup in FindSSOProviderByDomain, instead of as an internal error.
		if errors.Cause(err) == sql.ErrNoRows {
			return nil, SSOProviderNotFoundError{}
		}

		return nil, errors.Wrap(err, "error finding SAML SSO provider by ID (via EntityID)")
	}

	return &ssoProvider, nil
}
// FindSSOProviderByID loads the SSO provider with the given ID through
// tx.Eager(). It returns SSOProviderNotFoundError when no row matches, and a
// wrapped error for any other failure.
func FindSSOProviderByID(tx *storage.Connection, id uuid.UUID) (*SSOProvider, error) {
	provider := &SSOProvider{}

	err := tx.Eager().Q().Where("id = ?", id).First(provider)
	switch {
	case err == nil:
		return provider, nil
	case errors.Cause(err) == sql.ErrNoRows:
		return nil, SSOProviderNotFoundError{}
	default:
		return nil, errors.Wrap(err, "error finding SAML SSO provider by ID")
	}
}
// FindSSOProviderForEmailAddress looks up the SSO provider registered for
// the domain part of emailAddress.
//
// The domain is everything after the last "@", lower-cased, and is passed to
// FindSSOProviderByDomain. The last "@" is used because the local part of an
// address may itself contain "@", so splitting on the first one would pick
// the wrong domain. An address with no "@" has no domain to match and
// returns SSOProviderNotFoundError rather than panicking.
func FindSSOProviderForEmailAddress(tx *storage.Connection, emailAddress string) (*SSOProvider, error) {
	at := strings.LastIndex(emailAddress, "@")
	if at < 0 {
		return nil, SSOProviderNotFoundError{}
	}

	emailDomain := strings.ToLower(emailAddress[at+1:])

	return FindSSOProviderByDomain(tx, emailDomain)
}
// FindSSOProviderByDomain looks up the SSO provider that owns the given
// domain.
//
// It first finds the sso_domains row whose domain equals the argument
// exactly, then loads the owning SSOProvider by ID through tx.Eager(). If
// either lookup matches no row, SSOProviderNotFoundError is returned; any
// other failure is returned wrapped with context.
func FindSSOProviderByDomain(tx *storage.Connection, domain string) (*SSOProvider, error) {
	domainRow := SSODomain{}
	domainErr := tx.Q().Where("domain = ?", domain).First(&domainRow)
	if domainErr != nil {
		if errors.Cause(domainErr) != sql.ErrNoRows {
			return nil, errors.Wrap(domainErr, "error finding SAML SSO domain")
		}
		return nil, SSOProviderNotFoundError{}
	}

	provider := &SSOProvider{}
	providerErr := tx.Eager().Q().Where("id = ?", domainRow.SSOProviderID).First(provider)
	if providerErr != nil {
		if errors.Cause(providerErr) != sql.ErrNoRows {
			return nil, errors.Wrap(providerErr, "error finding SAML SSO provider by ID (via domain)")
		}
		return nil, SSOProviderNotFoundError{}
	}

	return provider, nil
}
// FindAllSAMLProviders loads every SSOProvider row through tx.Eager(). A
// sql.ErrNoRows failure is treated as "no providers" and yields (nil, nil);
// any other failure is returned wrapped with context.
func FindAllSAMLProviders(tx *storage.Connection) ([]SSOProvider, error) {
	var all []SSOProvider

	err := tx.Eager().All(&all)
	if err == nil {
		return all, nil
	}

	if errors.Cause(err) == sql.ErrNoRows {
		return nil, nil
	}

	return nil, errors.Wrap(err, "error loading all SAML SSO providers")
}
// FindSAMLRelayStateByID loads the SAML relay state with the given ID
// through tx.Eager(). It returns SAMLRelayStateNotFoundError when no row
// matches, and a wrapped error for any other failure.
func FindSAMLRelayStateByID(tx *storage.Connection, id uuid.UUID) (*SAMLRelayState, error) {
	relayState := &SAMLRelayState{}

	err := tx.Eager().Q().Where("id = ?", id).First(relayState)
	switch {
	case err == nil:
		return relayState, nil
	case errors.Cause(err) == sql.ErrNoRows:
		return nil, SAMLRelayStateNotFoundError{}
	default:
		return nil, errors.Wrap(err, "error loading SAML Relay State")
	}
}