package models import ( "database/sql" "database/sql/driver" "encoding/json" "reflect" "strings" "time" "github.com/crewjam/saml" "github.com/crewjam/saml/samlsp" "github.com/gofrs/uuid" "github.com/pkg/errors" "github.com/supabase/auth/internal/storage" ) type SSOProvider struct { ID uuid.UUID `db:"id" json:"id"` SAMLProvider SAMLProvider `has_one:"saml_providers" fk_id:"sso_provider_id" json:"saml,omitempty"` SSODomains []SSODomain `has_many:"sso_domains" fk_id:"sso_provider_id" json:"domains"` CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` } func (p SSOProvider) TableName() string { return "sso_providers" } func (p SSOProvider) Type() string { return "saml" } type SAMLAttribute struct { Name string `json:"name,omitempty"` Names []string `json:"names,omitempty"` Default interface{} `json:"default,omitempty"` Array bool `json:"array,omitempty"` } type SAMLAttributeMapping struct { Keys map[string]SAMLAttribute `json:"keys,omitempty"` } func (m *SAMLAttributeMapping) Equal(o *SAMLAttributeMapping) bool { if m == o { return true } if m == nil || o == nil { return false } if m.Keys == nil && o.Keys == nil { return true } if len(m.Keys) != len(o.Keys) { return false } for mkey, mvalue := range m.Keys { value, ok := o.Keys[mkey] if !ok { return false } if mvalue.Name != value.Name || len(mvalue.Names) != len(value.Names) { return false } for i := 0; i < len(mvalue.Names); i += 1 { if mvalue.Names[i] != value.Names[i] { return false } } if !reflect.DeepEqual(mvalue.Default, value.Default) { return false } if mvalue.Array != value.Array { return false } } return true } func (m *SAMLAttributeMapping) Scan(src interface{}) error { b, ok := src.([]byte) if !ok { return errors.New("scan source was not []byte") } err := json.Unmarshal(b, m) if err != nil { return err } return nil } func (m SAMLAttributeMapping) Value() (driver.Value, error) { b, err := json.Marshal(m) if err != nil { return nil, err } return string(b), nil } type SAMLProvider struct { ID uuid.UUID `db:"id" json:"-"` SSOProvider *SSOProvider `belongs_to:"sso_providers" json:"-"` SSOProviderID uuid.UUID `db:"sso_provider_id" json:"-"` EntityID string `db:"entity_id" json:"entity_id"` MetadataXML string `db:"metadata_xml" json:"metadata_xml,omitempty"` MetadataURL *string `db:"metadata_url" json:"metadata_url,omitempty"` AttributeMapping SAMLAttributeMapping `db:"attribute_mapping" json:"attribute_mapping,omitempty"` NameIDFormat *string `db:"name_id_format" json:"name_id_format,omitempty"` CreatedAt time.Time `db:"created_at" json:"-"` UpdatedAt time.Time `db:"updated_at" json:"-"` } func (p SAMLProvider) TableName() string { return "saml_providers" } func (p SAMLProvider) EntityDescriptor() (*saml.EntityDescriptor, error) { return samlsp.ParseMetadata([]byte(p.MetadataXML)) } type SSODomain struct { ID uuid.UUID `db:"id" json:"-"` SSOProvider *SSOProvider `belongs_to:"sso_providers" json:"-"` SSOProviderID uuid.UUID `db:"sso_provider_id" json:"-"` Domain string `db:"domain" json:"domain"` CreatedAt time.Time `db:"created_at" json:"-"` UpdatedAt time.Time `db:"updated_at" json:"-"` } func (d SSODomain) TableName() string { return "sso_domains" } type SAMLRelayState struct { ID uuid.UUID `db:"id"` SSOProviderID uuid.UUID `db:"sso_provider_id"` RequestID string `db:"request_id"` ForEmail *string `db:"for_email"` RedirectTo string `db:"redirect_to"` CreatedAt time.Time `db:"created_at" json:"-"` UpdatedAt time.Time `db:"updated_at" json:"-"` FlowStateID *uuid.UUID `db:"flow_state_id" json:"flow_state_id,omitempty"` FlowState *FlowState `db:"-" json:"flow_state,omitempty" belongs_to:"flow_state"` } func (s SAMLRelayState) TableName() string { return "saml_relay_states" } func FindSAMLProviderByEntityID(tx *storage.Connection, entityId string) (*SSOProvider, error) { var samlProvider SAMLProvider if err := tx.Q().Where("entity_id = ?", entityId).First(&samlProvider); err != nil { if errors.Cause(err) == sql.ErrNoRows { return nil, SSOProviderNotFoundError{} } return nil, errors.Wrap(err, "error finding SAML SSO provider by EntityID") } var ssoProvider SSOProvider if err := tx.Eager().Q().Where("id = ?", samlProvider.SSOProviderID).First(&ssoProvider); err != nil { return nil, errors.Wrap(err, "error finding SAML SSO provider by ID (via EntityID)") } return &ssoProvider, nil } func FindSSOProviderByID(tx *storage.Connection, id uuid.UUID) (*SSOProvider, error) { var ssoProvider SSOProvider if err := tx.Eager().Q().Where("id = ?", id).First(&ssoProvider); err != nil { if errors.Cause(err) == sql.ErrNoRows { return nil, SSOProviderNotFoundError{} } return nil, errors.Wrap(err, "error finding SAML SSO provider by ID") } return &ssoProvider, nil } func FindSSOProviderForEmailAddress(tx *storage.Connection, emailAddress string) (*SSOProvider, error) { parts := strings.Split(emailAddress, "@") emailDomain := strings.ToLower(parts[1]) return FindSSOProviderByDomain(tx, emailDomain) } func FindSSOProviderByDomain(tx *storage.Connection, domain string) (*SSOProvider, error) { var ssoDomain SSODomain if err := tx.Q().Where("domain = ?", domain).First(&ssoDomain); err != nil { if errors.Cause(err) == sql.ErrNoRows { return nil, SSOProviderNotFoundError{} } return nil, errors.Wrap(err, "error finding SAML SSO domain") } var ssoProvider SSOProvider if err := tx.Eager().Q().Where("id = ?", ssoDomain.SSOProviderID).First(&ssoProvider); err != nil { if errors.Cause(err) == sql.ErrNoRows { return nil, SSOProviderNotFoundError{} } return nil, errors.Wrap(err, "error finding SAML SSO provider by ID (via domain)") } return &ssoProvider, nil } func FindAllSAMLProviders(tx *storage.Connection) ([]SSOProvider, error) { var providers []SSOProvider if err := tx.Eager().All(&providers); err != nil { if errors.Cause(err) == sql.ErrNoRows { return nil, nil } return nil, errors.Wrap(err, "error loading all SAML SSO providers") } return providers, nil } func FindSAMLRelayStateByID(tx *storage.Connection, id uuid.UUID) (*SAMLRelayState, error) { var state SAMLRelayState if err := tx.Eager().Q().Where("id = ?", id).First(&state); err != nil { if errors.Cause(err) == sql.ErrNoRows { return nil, SAMLRelayStateNotFoundError{} } return nil, errors.Wrap(err, "error loading SAML Relay State") } return &state, nil }