1 change: 1 addition & 0 deletions docs/changelog/next_release/330.feature.rst
@@ -0,0 +1 @@
Implement SQL transformations in worker
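For context, a hedged sketch of the kind of transformations payload the worker now understands, based on the sql_transformation fixture added further down in this PR; the exact upstream schema may differ.

# Illustrative only: a transfer's transformations list with a "sql" entry.
# The query runs against a temporary view named "source" that the worker
# registers from the DataFrame it has just read.
transformations = [
    {
        "type": "sql",
        "query": "SELECT * FROM source WHERE NUMBER <= 20",
        "dialect": "spark",
    },
]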
15 changes: 14 additions & 1 deletion syncmaster/worker/handlers/db/base.py
@@ -58,7 +58,14 @@ def read(self) -> DataFrame:
columns=self._get_columns_filter_expressions(),
**reader_params,
)
return reader.run()
df = reader.run()

sql_query = self._get_sql_query()
if sql_query:
df.createOrReplaceTempView("source")
df = self.connection.spark.sql(sql_query)

return df

def write(self, df: DataFrame) -> None:
if self.transfer_dto.strategy.type == "incremental" and self.hwm and self.hwm.value:
@@ -110,6 +117,12 @@ def _get_columns_filter_expressions(self) -> list[str] | None:

return self._make_columns_filter_expressions(expressions)

def _get_sql_query(self) -> str | None:
for transformation in self.transfer_dto.transformations:
if transformation["type"] == "sql":
return transformation["query"]
return None

def _get_reading_options(self) -> dict:
options: dict[str, Any] = {}

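A minimal standalone sketch of the pattern introduced above, assuming an active SparkSession named spark, a freshly read DataFrame df, and a transformations list; these names and the pick_sql_query helper are placeholders for illustration, not SyncMaster's actual API.

def pick_sql_query(transformations):
    # Mirrors _get_sql_query(): use the first "sql" transformation's query, if any.
    for transformation in transformations:
        if transformation["type"] == "sql":
            return transformation["query"]
    return None

sql_query = pick_sql_query(transformations)
if sql_query:
    df.createOrReplaceTempView("source")  # expose the read result as the "source" view
    df = spark.sql(sql_query)             # replace df with the query result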
6 changes: 6 additions & 0 deletions syncmaster/worker/handlers/file/base.py
@@ -135,3 +135,9 @@ def _get_columns_filter_expressions(self) -> list[str] | None:
expressions.extend(transformation["filters"])

return self._make_columns_filter_expressions(expressions)

def _get_sql_query(self) -> str | None:
for transformation in self.transfer_dto.transformations:
if transformation["type"] == "sql":
return transformation["query"]
return None
5 changes: 5 additions & 0 deletions syncmaster/worker/handlers/file/local_df.py
@@ -58,6 +58,11 @@ def read(self) -> DataFrame:
if columns_filter_expressions:
df = df.selectExpr(*columns_filter_expressions)

sql_query = self._get_sql_query()
if sql_query:
df.createOrReplaceTempView("source")
df = self.df_connection.spark.sql(sql_query)

return df

def write(self, df: DataFrame) -> None:
5 changes: 5 additions & 0 deletions syncmaster/worker/handlers/file/remote_df.py
@@ -35,6 +35,11 @@ def read(self) -> DataFrame:
if columns_filter_expressions:
df = df.selectExpr(*columns_filter_expressions)

sql_query = self._get_sql_query()
if sql_query:
df.createOrReplaceTempView("source")
df = self.df_connection.spark.sql(sql_query)

return df

def write(self, df: DataFrame) -> None:
5 changes: 5 additions & 0 deletions syncmaster/worker/handlers/file/s3.py
@@ -71,6 +71,11 @@ def read(self) -> DataFrame:
if columns_filter_expressions:
df = df.selectExpr(*columns_filter_expressions)

sql_query = self._get_sql_query()
if sql_query:
df.createOrReplaceTempView("source")
df = self.df_connection.spark.sql(sql_query)

return df

@slot
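Across the file handlers above (local_df, remote_df, s3), the ordering is the same: column filter expressions are applied first, so the SQL query sees the already-filtered columns. A hedged sketch with placeholder names (reader, columns_filter_expressions, get_sql_query, spark stand in for the handlers' internals):

df = reader.run()                                    # raw read
if columns_filter_expressions:
    df = df.selectExpr(*columns_filter_expressions)  # column filters applied first
sql_query = get_sql_query()                          # stand-in for self._get_sql_query()
if sql_query:
    df.createOrReplaceTempView("source")             # SQL runs over the filtered columns
    df = spark.sql(sql_query)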
@@ -1,4 +1,5 @@
pytest_plugins = [
"tests.test_unit.test_scheduler.scheduler_fixtures",
"tests.test_integration.test_scheduler.scheduler_fixtures",
"tests.test_integration.test_run_transfer.connection_fixtures",
]
3 changes: 0 additions & 3 deletions tests/test_integration/test_run_transfer/conftest.py

This file was deleted.

@@ -24,7 +24,9 @@
dataframe_rows_filter_transformations,
expected_dataframe_columns_filter,
expected_dataframe_rows_filter,
expected_sql_transformation,
file_metadata_filter_transformations,
sql_transformation,
)
from tests.test_integration.test_run_transfer.connection_fixtures.ftp_fixtures import (
ftp_connection,
@@ -151,6 +153,7 @@
"dataframe_rows_filter_transformations",
"expected_dataframe_columns_filter",
"expected_dataframe_rows_filter",
"expected_sql_transformation",
"file_format_flavor",
"file_metadata_filter_transformations",
"ftp_connection",
@@ -239,6 +242,7 @@
"sftp_for_worker",
"source_file_format",
"spark",
"sql_transformation",
"target_file_format",
"update_transfer_strategy",
"webdav_connection",
@@ -133,3 +133,19 @@ def file_metadata_filter_transformations():
],
},
]


@pytest.fixture
def sql_transformation():
return [
{
"type": "sql",
"query": "SELECT * FROM source WHERE NUMBER <= 20",
"dialect": "spark",
},
]


@pytest.fixture
def expected_sql_transformation():
return lambda df, source_type: df.filter(df.NUMBER <= 20)
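The fixture's SQL query and the expected_sql_transformation lambda are intended to be equivalent ways of keeping rows with NUMBER <= 20. A hedged, self-contained check of that equivalence (the local SparkSession and the NUMBER column are assumptions for illustration, not part of the test suite):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(i,) for i in range(1, 41)], ["NUMBER"])

df.createOrReplaceTempView("source")
via_sql = spark.sql("SELECT * FROM source WHERE NUMBER <= 20")
via_filter = df.filter(df.NUMBER <= 20)

# Both paths should keep exactly the rows with NUMBER <= 20.
assert via_sql.orderBy("NUMBER").collect() == via_filter.orderBy("NUMBER").collect()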
61 changes: 61 additions & 0 deletions tests/test_integration/test_run_transfer/test_s3.py
@@ -381,3 +381,64 @@ async def test_run_transfer_postgres_to_s3_with_incremental_strategy(
df_with_increment = reader.run()
df_with_increment, init_df = cast_dataframe_types(df_with_increment, init_df)
assert df_with_increment.sort("id").collect() == init_df.sort("id").collect()


@pytest.mark.parametrize(
[
"target_file_format",
"file_format_flavor",
"strategy",
"transformations",
"expected_extension",
],
[
pytest.param(
("parquet", {}),
"without_compression",
lf("full_strategy"),
lf("sql_transformation"),
"parquet",
id="sql_transformation",
),
],
indirect=["target_file_format", "file_format_flavor"],
)
async def test_run_transfer_postgres_to_s3_with_sql_transformation(
group_owner: MockUser,
init_df: DataFrame,
client: AsyncClient,
s3_file_df_connection: SparkS3,
s3_file_connection: S3,
prepare_postgres,
prepare_s3,
postgres_to_s3: Transfer,
target_file_format,
file_format_flavor: str,
strategy,
transformations,
expected_sql_transformation,
expected_extension: str,
):
format_name, format = target_file_format
target_path = f"/target/{format_name}/{file_format_flavor}"
_, fill_with_data = prepare_postgres
fill_with_data(init_df)

init_df = expected_sql_transformation(init_df, "postgres")

await run_transfer_and_verify(client, group_owner, postgres_to_s3.id, target_auth="s3")

file_names = [file.name for file in s3_file_connection.list_dir(target_path) if file.is_file()]
verify_file_name_template(file_names, expected_extension)

reader = FileDFReader(
connection=s3_file_df_connection,
format=format,
source_path=target_path,
df_schema=init_df.schema,
options={},
)
df = reader.run()

df, init_df = cast_dataframe_types(df, init_df)
assert df.sort("id").collect() == init_df.sort("id").collect()