Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions src/workos/directory_sync.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Protocol
from typing import Any, Dict, Optional, Protocol, Union

from workos.types.directory_sync.list_filters import (
DirectoryGroupListFilters,
Expand Down Expand Up @@ -32,6 +32,24 @@
Directory, DirectoryListFilters, ListMetadata
]

# Mapping from SDK parameter names to API parameter names
PARAM_KEY_MAPPING = {
"directory_id": "directory",
"group_id": "group",
"user_id": "user",
}


def _prepare_request_params(
list_params: Union[DirectoryUserListFilters, DirectoryGroupListFilters],
) -> Dict[str, Any]:
"""Convert list_params to API request params by renaming keys."""
request_params: Dict[str, Any] = dict(list_params)
for sdk_key, api_key in PARAM_KEY_MAPPING.items():
if sdk_key in request_params:
request_params[api_key] = request_params.pop(sdk_key)
return request_params


class DirectorySyncModule(Protocol):
"""Offers methods through the WorkOS Directory Sync service."""
Expand Down Expand Up @@ -191,7 +209,7 @@ def list_users(
response = self._http_client.request(
"directory_users",
method=REQUEST_METHOD_GET,
params=list_params,
params=_prepare_request_params(list_params),
)

return WorkOSListResource(
Expand Down Expand Up @@ -225,7 +243,7 @@ def list_groups(
response = self._http_client.request(
"directory_groups",
method=REQUEST_METHOD_GET,
params=list_params,
params=_prepare_request_params(list_params),
)

return WorkOSListResource[
Expand Down Expand Up @@ -329,7 +347,7 @@ async def list_users(
response = await self._http_client.request(
"directory_users",
method=REQUEST_METHOD_GET,
params=list_params,
params=_prepare_request_params(list_params),
)

return WorkOSListResource(
Expand All @@ -354,6 +372,7 @@ async def list_groups(
"after": after,
"order": order,
}

if user_id is not None:
list_params["user_id"] = user_id
if directory_id is not None:
Expand All @@ -362,7 +381,7 @@ async def list_groups(
response = await self._http_client.request(
"directory_groups",
method=REQUEST_METHOD_GET,
params=list_params,
params=_prepare_request_params(list_params),
)

return WorkOSListResource[
Expand Down
64 changes: 59 additions & 5 deletions tests/test_directory_sync.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Union

import pytest
from workos.directory_sync import AsyncDirectorySync, DirectorySync

from tests.types.test_auto_pagination_function import TestAutoPaginationFunction
from tests.utils.fixtures.mock_directory import (
Expand All @@ -13,6 +12,11 @@
from tests.utils.fixtures.mock_directory_user import MockDirectoryUser
from tests.utils.list_resource import list_data_to_dicts, list_response_of
from tests.utils.syncify import syncify
from workos.directory_sync import (
_prepare_request_params,
AsyncDirectorySync,
DirectorySync,
)


def api_directory_to_sdk(directory):
Expand Down Expand Up @@ -145,7 +149,7 @@ def test_list_users_with_directory(
assert request_kwargs["url"].endswith("/directory_users")
assert request_kwargs["method"] == "get"
assert request_kwargs["params"] == {
"directory_id": "directory_id",
"directory": "directory_id",
"limit": 10,
"order": "desc",
}
Expand All @@ -163,7 +167,7 @@ def test_list_users_with_group(
assert request_kwargs["url"].endswith("/directory_users")
assert request_kwargs["method"] == "get"
assert request_kwargs["params"] == {
"group_id": "directory_grp_id",
"group": "directory_grp_id",
"limit": 10,
"order": "desc",
}
Expand All @@ -181,7 +185,7 @@ def test_list_groups_with_directory(
assert request_kwargs["url"].endswith("/directory_groups")
assert request_kwargs["method"] == "get"
assert request_kwargs["params"] == {
"directory_id": "directory_id",
"directory": "directory_id",
"limit": 10,
"order": "desc",
}
Expand All @@ -199,7 +203,7 @@ def test_list_groups_with_user(
assert request_kwargs["url"].endswith("/directory_groups")
assert request_kwargs["method"] == "get"
assert request_kwargs["params"] == {
"user_id": "directory_user_id",
"user": "directory_user_id",
"limit": 10,
"order": "desc",
}
Expand Down Expand Up @@ -371,3 +375,53 @@ def test_directory_user_groups_auto_pagination(
list_function=self.directory_sync.list_groups,
expected_all_page_data=mock_directory_groups_multiple_data_pages,
)


class TestPrepareRequestParams:
"""Tests for SDK-to-API parameter name translation.

The SDK uses Pythonic parameter names (directory_id, group_id, user_id)
but the WorkOS API expects shorter names (directory, group, user).
The _prepare_request_params function handles this translation.

See: https://github.com/workos/workos-python/issues/511
See: https://github.com/workos/workos-python/issues/519
"""

def test_translates_directory_id_to_directory(self):
params = {"directory_id": "dir_123", "limit": 10}
result = _prepare_request_params(params)
assert "directory" in result
assert "directory_id" not in result
assert result["directory"] == "dir_123"

def test_translates_group_id_to_group(self):
params = {"group_id": "grp_123", "limit": 10}
result = _prepare_request_params(params)
assert "group" in result
assert "group_id" not in result
assert result["group"] == "grp_123"

def test_translates_user_id_to_user(self):
params = {"user_id": "usr_123", "limit": 10}
result = _prepare_request_params(params)
assert "user" in result
assert "user_id" not in result
assert result["user"] == "usr_123"

def test_preserves_non_id_params(self):
params = {
"directory_id": "dir_123",
"limit": 10,
"order": "desc",
"after": "cursor",
}
result = _prepare_request_params(params)
assert result["limit"] == 10
assert result["order"] == "desc"
assert result["after"] == "cursor"

def test_handles_empty_params(self):
params = {"limit": 10, "order": "desc"}
result = _prepare_request_params(params)
assert result == {"limit": 10, "order": "desc"}