Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/ldap_protocol/ldap_requests/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ async def handle( # noqa: C901
select(Directory)
.options(
joinedload(qa(Directory.user)),
joinedload(qa(Directory.entity_type)),
selectinload(qa(Directory.groups)).selectinload(
qa(Group.directory),
),
Expand Down
152 changes: 123 additions & 29 deletions app/ldap_protocol/ldap_requests/modify_dn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from sqlalchemy import delete, func, select, text, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload, selectinload

from entities import AccessControlEntry, Attribute, Directory
Expand All @@ -18,7 +19,13 @@
INVALID_ACCESS_RESPONSE,
ModifyDNResponse,
)
from ldap_protocol.objects import ProtocolRequests
from ldap_protocol.objects import (
Changes,
Operation,
PartialAttribute,
ProtocolRequests,
)
from ldap_protocol.roles.access_manager import AccessManager
from ldap_protocol.utils.queries import get_filter_from_path, validate_entry
from repo.pg.tables import (
ace_directory_memberships_table,
Expand Down Expand Up @@ -85,7 +92,75 @@ def from_data(cls, data: list[ASN1Row]) -> "ModifyDNRequest":
new_superior=None if len(data) < 4 else data[3].value,
)

async def handle(
def _is_move_to_new_superior(self, directory: Directory) -> bool:
return bool(
self.new_superior
and directory.parent
and self.new_superior != directory.parent.path_dn,
)

def _can_rename(
self,
access_manager: AccessManager,
directory: Directory,
name: str,
) -> bool:
return access_manager.check_modify_access(
changes=[
Changes(
operation=Operation.REPLACE,
modification=PartialAttribute(type="name", vals=[name]),
),
],
aces=directory.access_control_entries,
entity_type_id=directory.entity_type_id,
)

async def _delete_old_inherited_aces(
self,
session: AsyncSession,
directory: Directory,
old_depth: int,
) -> None:
old_inherited_aces_query = (
select(qa(AccessControlEntry.id))
.options(selectinload(qa(AccessControlEntry.directories)))
.where(
qa(AccessControlEntry.directories).contains(directory),
qa(AccessControlEntry.depth) != old_depth,
)
)
await session.execute(
delete(ace_directory_memberships_table)
.filter_by(
directory_id=directory.id,
)
.where(
ace_directory_memberships_table.c.access_control_entry_id.in_(
old_inherited_aces_query,
),
),
)

async def _update_explicit_aces(
self,
session: AsyncSession,
directory: Directory,
old_depth: int,
) -> None:
explicit_aces_query = (
select(AccessControlEntry)
.options(selectinload(qa(AccessControlEntry.directories)))
.where(
qa(AccessControlEntry.directories).contains(directory),
qa(AccessControlEntry.depth) == old_depth,
)
)
for ace in await session.scalars(explicit_aces_query):
ace.path = directory.path_dn
ace.depth = directory.depth

async def handle( # noqa: C901
self,
ctx: LDAPModifyDNRequestContext,
) -> AsyncGenerator[ModifyDNResponse, None]:
Expand Down Expand Up @@ -122,8 +197,8 @@ async def handle(
query = ctx.access_manager.mutate_query_with_ace_load(
user_role_ids=ctx.ldap_session.user.role_ids,
query=query,
ace_types=[AceType.DELETE],
require_attribute_type_null=True,
ace_types=[AceType.DELETE, AceType.WRITE],
load_attribute_type=True,
)

directory = await ctx.session.scalar(query)
Expand All @@ -142,8 +217,20 @@ async def handle(
)
return

old_name = directory.name
new_dn, new_name = self.newrdn.split("=")
is_move_to_new_superior = self._is_move_to_new_superior(directory)

if not is_move_to_new_superior and not self._can_rename(
ctx.access_manager,
directory,
new_name,
):
yield ModifyDNResponse(
result_code=LDAPCodes.INSUFFICIENT_ACCESS_RIGHTS,
)
return

old_name = directory.name
directory.name = new_name

old_path = directory.path
Expand All @@ -166,13 +253,29 @@ async def handle(
)
return

if (
self.new_superior
and directory.parent
and self.new_superior != directory.parent.path_dn
):
if is_move_to_new_superior:
delete_aces = [
ace
for ace in directory.access_control_entries
if (
ace.ace_type == AceType.DELETE
and ace.attribute_type is None
)
]

can_delete = ctx.access_manager.check_entity_level_access(
aces=delete_aces,
entity_type_id=directory.entity_type_id,
)

if not can_delete:
yield ModifyDNResponse(
result_code=LDAPCodes.INSUFFICIENT_ACCESS_RIGHTS,
)
return

new_sup_query = select(Directory).filter(
get_filter_from_path(self.new_superior),
get_filter_from_path(self.new_superior), # type: ignore
)
new_sup_query = ctx.access_manager.mutate_query_with_ace_load(
user_role_ids=ctx.ldap_session.user.role_ids,
Expand Down Expand Up @@ -203,11 +306,11 @@ async def handle(

try:
await ctx.session.flush()
await ctx.session.execute(
delete(ace_directory_memberships_table)
.filter_by(directory_id=directory.id),
) # fmt: skip

await self._delete_old_inherited_aces(
ctx.session,
directory=directory,
old_depth=old_depth,
)
await ctx.role_use_case.inherit_parent_aces(
parent_directory=directory.parent,
directory=directory,
Expand Down Expand Up @@ -265,20 +368,11 @@ async def handle(
)
await ctx.session.flush()

explicit_aces_query = (
select(AccessControlEntry)
.options(selectinload(qa(AccessControlEntry.directories)))
.where(
qa(AccessControlEntry.directories).any(
qa(Directory.id) == directory.id,
),
qa(AccessControlEntry.depth) == old_depth,
)
await self._update_explicit_aces(
ctx.session,
directory,
old_depth,
)
for ace in await ctx.session.scalars(explicit_aces_query):
ace.directories.append(directory)
ace.path = directory.path_dn
ace.depth = directory.depth

await ctx.session.flush()

Expand Down
15 changes: 11 additions & 4 deletions app/ldap_protocol/roles/access_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,12 +304,19 @@ def mutate_query_with_ace_load(
null attribute_type_id
:return: mutated query with access control entries loaded
"""
selectin_loader = selectinload(
base_loader = selectinload(
qa(Directory.access_control_entries),
)

loader_options = [
base_loader.joinedload(qa(AccessControlEntry.entity_type)),
]

if load_attribute_type:
selectin_loader = selectin_loader.joinedload(
qa(AccessControlEntry.attribute_type),
loader_options.append(
base_loader.joinedload(
qa(AccessControlEntry.attribute_type),
),
)

criteria_conditions = [
Expand All @@ -331,7 +338,7 @@ def mutate_query_with_ace_load(
)

return query.options(
selectin_loader,
*loader_options,
with_loader_criteria(
AccessControlEntry,
and_(*criteria_conditions),
Expand Down
Loading