summary refs log tree commit diff
path: root/vendor/maunium.net/go/mautrix/statestore.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/maunium.net/go/mautrix/statestore.go')
-rw-r--r--vendor/maunium.net/go/mautrix/statestore.go344
1 files changed, 344 insertions, 0 deletions
diff --git a/vendor/maunium.net/go/mautrix/statestore.go b/vendor/maunium.net/go/mautrix/statestore.go
new file mode 100644
index 0000000..e728b88
--- /dev/null
+++ b/vendor/maunium.net/go/mautrix/statestore.go
@@ -0,0 +1,344 @@
+// Copyright (c) 2023 Tulir Asokan
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package mautrix
+
+import (
+	"context"
+	"maps"
+	"sync"
+
+	"github.com/rs/zerolog"
+	"go.mau.fi/util/exerrors"
+
+	"maunium.net/go/mautrix/event"
+	"maunium.net/go/mautrix/id"
+)
+
+// StateStore is an interface for storing basic room state information.
+type StateStore interface {
+	IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool
+	IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool
+	IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool
+	GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error)
+	TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error)
+	SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error
+	SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error
+	IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error)
+	ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error
+	ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error
+
+	SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error
+	GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error)
+
+	HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error)
+	MarkMembersFetched(ctx context.Context, roomID id.RoomID) error
+	GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error)
+
+	SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error
+	IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error)
+
+	GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error)
+}
+
+type StateStoreUpdater interface {
+	UpdateState(ctx context.Context, evt *event.Event)
+}
+
+func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) {
+	if store == nil || evt == nil || evt.StateKey == nil {
+		return
+	}
+	if directUpdater, ok := store.(StateStoreUpdater); ok {
+		directUpdater.UpdateState(ctx, evt)
+		return
+	}
+	// We only care about events without a state key (power levels, encryption) or member events with state key
+	if evt.Type != event.StateMember && evt.GetStateKey() != "" {
+		return
+	}
+	var err error
+	switch content := evt.Content.Parsed.(type) {
+	case *event.MemberEventContent:
+		err = store.SetMember(ctx, evt.RoomID, id.UserID(evt.GetStateKey()), content)
+	case *event.PowerLevelsEventContent:
+		err = store.SetPowerLevels(ctx, evt.RoomID, content)
+	case *event.EncryptionEventContent:
+		err = store.SetEncryptionEvent(ctx, evt.RoomID, content)
+	default:
+		switch evt.Type {
+		case event.StateMember, event.StatePowerLevels, event.StateEncryption:
+			zerolog.Ctx(ctx).Warn().
+				Stringer("event_id", evt.ID).
+				Str("event_type", evt.Type.Type).
+				Type("content_type", evt.Content.Parsed).
+				Msg("Got known event type with unknown content type in UpdateStateStore")
+		}
+	}
+	if err != nil {
+		zerolog.Ctx(ctx).Warn().Err(err).
+			Stringer("event_id", evt.ID).
+			Str("event_type", evt.Type.Type).
+			Msg("Failed to update state store")
+	}
+}
+
+// StateStoreSyncHandler can be added as an event handler in the syncer to update the state store automatically.
+//
+//	client.Syncer.(mautrix.ExtensibleSyncer).OnEvent(client.StateStoreSyncHandler)
+//
+// DefaultSyncer.ParseEventContent must also be true for this to work (which it is by default).
+func (cli *Client) StateStoreSyncHandler(ctx context.Context, evt *event.Event) {
+	UpdateStateStore(ctx, cli.StateStore, evt)
+}
+
+type MemoryStateStore struct {
+	Registrations  map[id.UserID]bool                                    `json:"registrations"`
+	Members        map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"`
+	MembersFetched map[id.RoomID]bool                                    `json:"members_fetched"`
+	PowerLevels    map[id.RoomID]*event.PowerLevelsEventContent          `json:"power_levels"`
+	Encryption     map[id.RoomID]*event.EncryptionEventContent           `json:"encryption"`
+
+	registrationsLock sync.RWMutex
+	membersLock       sync.RWMutex
+	powerLevelsLock   sync.RWMutex
+	encryptionLock    sync.RWMutex
+}
+
+func NewMemoryStateStore() StateStore {
+	return &MemoryStateStore{
+		Registrations:  make(map[id.UserID]bool),
+		Members:        make(map[id.RoomID]map[id.UserID]*event.MemberEventContent),
+		MembersFetched: make(map[id.RoomID]bool),
+		PowerLevels:    make(map[id.RoomID]*event.PowerLevelsEventContent),
+		Encryption:     make(map[id.RoomID]*event.EncryptionEventContent),
+	}
+}
+
+func (store *MemoryStateStore) IsRegistered(_ context.Context, userID id.UserID) (bool, error) {
+	store.registrationsLock.RLock()
+	defer store.registrationsLock.RUnlock()
+	registered, ok := store.Registrations[userID]
+	return ok && registered, nil
+}
+
+func (store *MemoryStateStore) MarkRegistered(_ context.Context, userID id.UserID) error {
+	store.registrationsLock.Lock()
+	defer store.registrationsLock.Unlock()
+	store.Registrations[userID] = true
+	return nil
+}
+
+func (store *MemoryStateStore) GetRoomMembers(_ context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
+	store.membersLock.RLock()
+	members, ok := store.Members[roomID]
+	store.membersLock.RUnlock()
+	if !ok {
+		members = make(map[id.UserID]*event.MemberEventContent)
+		store.membersLock.Lock()
+		store.Members[roomID] = members
+		store.membersLock.Unlock()
+	}
+	return members, nil
+}
+
+func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) {
+	members, err := store.GetRoomMembers(ctx, roomID)
+	if err != nil {
+		return nil, err
+	}
+	ids := make([]id.UserID, 0, len(members))
+	for id := range members {
+		ids = append(ids, id)
+	}
+	return ids, nil
+}
+
+func (store *MemoryStateStore) GetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID) (event.Membership, error) {
+	return exerrors.Must(store.GetMember(ctx, roomID, userID)).Membership, nil
+}
+
+func (store *MemoryStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
+	member, err := store.TryGetMember(ctx, roomID, userID)
+	if member == nil && err == nil {
+		member = &event.MemberEventContent{Membership: event.MembershipLeave}
+	}
+	return member, err
+}
+
+func (store *MemoryStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) {
+	// TODO implement?
+	return nil, nil
+}
+
+func (store *MemoryStateStore) TryGetMember(_ context.Context, roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, err error) {
+	store.membersLock.RLock()
+	defer store.membersLock.RUnlock()
+	members, membersOk := store.Members[roomID]
+	if !membersOk {
+		return
+	}
+	member = members[userID]
+	return
+}
+
+func (store *MemoryStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
+	return store.IsMembership(ctx, roomID, userID, "join")
+}
+
+func (store *MemoryStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
+	return store.IsMembership(ctx, roomID, userID, "join", "invite")
+}
+
+func (store *MemoryStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
+	membership := exerrors.Must(store.GetMembership(ctx, roomID, userID))
+	for _, allowedMembership := range allowedMemberships {
+		if allowedMembership == membership {
+			return true
+		}
+	}
+	return false
+}
+
+func (store *MemoryStateStore) SetMembership(_ context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
+	store.membersLock.Lock()
+	members, ok := store.Members[roomID]
+	if !ok {
+		members = map[id.UserID]*event.MemberEventContent{
+			userID: {Membership: membership},
+		}
+	} else {
+		member, ok := members[userID]
+		if !ok {
+			members[userID] = &event.MemberEventContent{Membership: membership}
+		} else {
+			member.Membership = membership
+			members[userID] = member
+		}
+	}
+	store.Members[roomID] = members
+	store.membersLock.Unlock()
+	return nil
+}
+
+func (store *MemoryStateStore) SetMember(_ context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
+	store.membersLock.Lock()
+	members, ok := store.Members[roomID]
+	if !ok {
+		members = map[id.UserID]*event.MemberEventContent{
+			userID: member,
+		}
+	} else {
+		members[userID] = member
+	}
+	store.Members[roomID] = members
+	store.membersLock.Unlock()
+	return nil
+}
+
+func (store *MemoryStateStore) ClearCachedMembers(_ context.Context, roomID id.RoomID, memberships ...event.Membership) error {
+	store.membersLock.Lock()
+	defer store.membersLock.Unlock()
+	members, ok := store.Members[roomID]
+	if !ok {
+		return nil
+	}
+	for userID, member := range members {
+		for _, membership := range memberships {
+			if membership == member.Membership {
+				delete(members, userID)
+				break
+			}
+		}
+	}
+	store.MembersFetched[roomID] = false
+	return nil
+}
+
+func (store *MemoryStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) {
+	store.membersLock.RLock()
+	defer store.membersLock.RUnlock()
+	return store.MembersFetched[roomID], nil
+}
+
+func (store *MemoryStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error {
+	store.membersLock.Lock()
+	defer store.membersLock.Unlock()
+	store.MembersFetched[roomID] = true
+	return nil
+}
+
+func (store *MemoryStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error {
+	_ = store.ClearCachedMembers(ctx, roomID, onlyMemberships...)
+	for _, evt := range evts {
+		UpdateStateStore(ctx, store, evt)
+	}
+	if len(onlyMemberships) == 0 {
+		_ = store.MarkMembersFetched(ctx, roomID)
+	}
+	return nil
+}
+
+func (store *MemoryStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
+	store.membersLock.RLock()
+	defer store.membersLock.RUnlock()
+	return maps.Clone(store.Members[roomID]), nil
+}
+
+func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
+	store.powerLevelsLock.Lock()
+	store.PowerLevels[roomID] = levels
+	store.powerLevelsLock.Unlock()
+	return nil
+}
+
+func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
+	store.powerLevelsLock.RLock()
+	levels = store.PowerLevels[roomID]
+	store.powerLevelsLock.RUnlock()
+	return
+}
+
+func (store *MemoryStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) {
+	return exerrors.Must(store.GetPowerLevels(ctx, roomID)).GetUserLevel(userID), nil
+}
+
+func (store *MemoryStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) {
+	return exerrors.Must(store.GetPowerLevels(ctx, roomID)).GetEventLevel(eventType), nil
+}
+
+func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) {
+	return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil
+}
+
+func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
+	store.encryptionLock.Lock()
+	store.Encryption[roomID] = content
+	store.encryptionLock.Unlock()
+	return nil
+}
+
+func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) {
+	store.encryptionLock.RLock()
+	defer store.encryptionLock.RUnlock()
+	return store.Encryption[roomID], nil
+}
+
+func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) {
+	cfg, err := store.GetEncryptionEvent(ctx, roomID)
+	return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err
+}
+
+func (store *MemoryStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) (rooms []id.RoomID, err error) {
+	store.membersLock.RLock()
+	defer store.membersLock.RUnlock()
+	for roomID, members := range store.Members {
+		if _, ok := members[userID]; ok {
+			rooms = append(rooms, roomID)
+		}
+	}
+	return rooms, nil
+}