From e3c69f169669e5175b923409f51a347393e26c3d Mon Sep 17 00:00:00 2001
From: Ilyas Gasanov
Date: Tue, 3 Feb 2026 12:05:20 +0300
Subject: [PATCH] [DOP-31721] Implement SQL transformations

---
 docs/changelog/next_release/330.feature.rst  |  1 +
 syncmaster/worker/handlers/db/base.py        | 15 ++++-
 syncmaster/worker/handlers/file/base.py      |  6 ++
 syncmaster/worker/handlers/file/local_df.py  |  5 ++
 syncmaster/worker/handlers/file/remote_df.py |  5 ++
 syncmaster/worker/handlers/file/s3.py        |  5 ++
 .../{test_scheduler => }/conftest.py         |  1 +
 .../test_run_transfer/conftest.py            |  3 -
 .../connection_fixtures/__init__.py          |  4 ++
 .../connection_fixtures/filters_fixtures.py  | 16 +++++
 .../test_run_transfer/test_s3.py             | 61 +++++++++++++++++++
 11 files changed, 118 insertions(+), 4 deletions(-)
 create mode 100644 docs/changelog/next_release/330.feature.rst
 rename tests/test_integration/{test_scheduler => }/conftest.py (67%)
 delete mode 100644 tests/test_integration/test_run_transfer/conftest.py

diff --git a/docs/changelog/next_release/330.feature.rst b/docs/changelog/next_release/330.feature.rst
new file mode 100644
index 00000000..76fe1448
--- /dev/null
+++ b/docs/changelog/next_release/330.feature.rst
@@ -0,0 +1 @@
+Implement SQL transformations in worker
\ No newline at end of file
diff --git a/syncmaster/worker/handlers/db/base.py b/syncmaster/worker/handlers/db/base.py
index 10e77ab8..bec26cae 100644
--- a/syncmaster/worker/handlers/db/base.py
+++ b/syncmaster/worker/handlers/db/base.py
@@ -58,7 +58,14 @@ def read(self) -> DataFrame:
             columns=self._get_columns_filter_expressions(),
             **reader_params,
         )
-        return reader.run()
+        df = reader.run()
+
+        sql_query = self._get_sql_query()
+        if sql_query:
+            df.createOrReplaceTempView("source")
+            df = self.connection.spark.sql(sql_query)
+
+        return df

     def write(self, df: DataFrame) -> None:
         if self.transfer_dto.strategy.type == "incremental" and self.hwm and self.hwm.value:
@@ -110,6 +117,12 @@ def _get_columns_filter_expressions(self) -> list[str] | None:

         return self._make_columns_filter_expressions(expressions)

+    def _get_sql_query(self) -> str | None:
+        for transformation in self.transfer_dto.transformations:
+            if transformation["type"] == "sql":
+                return transformation["query"]
+        return None
+
     def _get_reading_options(self) -> dict:
         options: dict[str, Any] = {}

diff --git a/syncmaster/worker/handlers/file/base.py b/syncmaster/worker/handlers/file/base.py
index f6787f66..5effcc5c 100644
--- a/syncmaster/worker/handlers/file/base.py
+++ b/syncmaster/worker/handlers/file/base.py
@@ -135,3 +135,9 @@ def _get_columns_filter_expressions(self) -> list[str] | None:
             expressions.extend(transformation["filters"])

         return self._make_columns_filter_expressions(expressions)
+
+    def _get_sql_query(self) -> str | None:
+        for transformation in self.transfer_dto.transformations:
+            if transformation["type"] == "sql":
+                return transformation["query"]
+        return None
diff --git a/syncmaster/worker/handlers/file/local_df.py b/syncmaster/worker/handlers/file/local_df.py
index b0799458..9285b7e8 100644
--- a/syncmaster/worker/handlers/file/local_df.py
+++ b/syncmaster/worker/handlers/file/local_df.py
@@ -58,6 +58,11 @@ def read(self) -> DataFrame:
         if columns_filter_expressions:
             df = df.selectExpr(*columns_filter_expressions)

+        sql_query = self._get_sql_query()
+        if sql_query:
+            df.createOrReplaceTempView("source")
+            df = self.df_connection.spark.sql(sql_query)
+
         return df

     def write(self, df: DataFrame) -> None:
diff --git a/syncmaster/worker/handlers/file/remote_df.py b/syncmaster/worker/handlers/file/remote_df.py
index 2191c12f..0695a911 100644
--- a/syncmaster/worker/handlers/file/remote_df.py
+++ b/syncmaster/worker/handlers/file/remote_df.py
@@ -35,6 +35,11 @@ def read(self) -> DataFrame:
         if columns_filter_expressions:
             df = df.selectExpr(*columns_filter_expressions)

+        sql_query = self._get_sql_query()
+        if sql_query:
+            df.createOrReplaceTempView("source")
+            df = self.df_connection.spark.sql(sql_query)
+
         return df

     def write(self, df: DataFrame) -> None:
diff --git a/syncmaster/worker/handlers/file/s3.py b/syncmaster/worker/handlers/file/s3.py
index d4f592dc..8d305a97 100644
--- a/syncmaster/worker/handlers/file/s3.py
+++ b/syncmaster/worker/handlers/file/s3.py
@@ -71,6 +71,11 @@ def read(self) -> DataFrame:
         if columns_filter_expressions:
             df = df.selectExpr(*columns_filter_expressions)

+        sql_query = self._get_sql_query()
+        if sql_query:
+            df.createOrReplaceTempView("source")
+            df = self.df_connection.spark.sql(sql_query)
+
         return df

     @slot
diff --git a/tests/test_integration/test_scheduler/conftest.py b/tests/test_integration/conftest.py
similarity index 67%
rename from tests/test_integration/test_scheduler/conftest.py
rename to tests/test_integration/conftest.py
index b094ae35..85087943 100644
--- a/tests/test_integration/test_scheduler/conftest.py
+++ b/tests/test_integration/conftest.py
@@ -1,4 +1,5 @@
 pytest_plugins = [
     "tests.test_unit.test_scheduler.scheduler_fixtures",
     "tests.test_integration.test_scheduler.scheduler_fixtures",
+    "tests.test_integration.test_run_transfer.connection_fixtures",
 ]
diff --git a/tests/test_integration/test_run_transfer/conftest.py b/tests/test_integration/test_run_transfer/conftest.py
deleted file mode 100644
index 7b9ca128..00000000
--- a/tests/test_integration/test_run_transfer/conftest.py
+++ /dev/null
@@ -1,3 +0,0 @@
-pytest_plugins = [
-    "tests.test_integration.test_run_transfer.connection_fixtures",
-]
diff --git a/tests/test_integration/test_run_transfer/connection_fixtures/__init__.py b/tests/test_integration/test_run_transfer/connection_fixtures/__init__.py
index 63717418..af033c23 100644
--- a/tests/test_integration/test_run_transfer/connection_fixtures/__init__.py
+++ b/tests/test_integration/test_run_transfer/connection_fixtures/__init__.py
@@ -24,7 +24,9 @@
     dataframe_rows_filter_transformations,
     expected_dataframe_columns_filter,
     expected_dataframe_rows_filter,
+    expected_sql_transformation,
     file_metadata_filter_transformations,
+    sql_transformation,
 )
 from tests.test_integration.test_run_transfer.connection_fixtures.ftp_fixtures import (
     ftp_connection,
@@ -151,6 +153,7 @@
     "dataframe_rows_filter_transformations",
     "expected_dataframe_columns_filter",
     "expected_dataframe_rows_filter",
+    "expected_sql_transformation",
     "file_format_flavor",
     "file_metadata_filter_transformations",
     "ftp_connection",
@@ -239,6 +242,7 @@
     "sftp_for_worker",
     "source_file_format",
     "spark",
+    "sql_transformation",
     "target_file_format",
     "update_transfer_strategy",
     "webdav_connection",
diff --git a/tests/test_integration/test_run_transfer/connection_fixtures/filters_fixtures.py b/tests/test_integration/test_run_transfer/connection_fixtures/filters_fixtures.py
index 62dd00a2..6d973ba0 100644
--- a/tests/test_integration/test_run_transfer/connection_fixtures/filters_fixtures.py
+++ b/tests/test_integration/test_run_transfer/connection_fixtures/filters_fixtures.py
@@ -133,3 +133,19 @@ def file_metadata_filter_transformations():
         ],
     },
 ]
+
+
+@pytest.fixture
+def sql_transformation():
+    return [
+        {
+            "type": "sql",
+            "query": "SELECT * FROM source WHERE NUMBER <= 20",
+            "dialect": "spark",
+        },
+    ]
+
+
+@pytest.fixture
+def expected_sql_transformation():
+    return lambda df, source_type: df.filter(df.NUMBER <= 20)
diff --git a/tests/test_integration/test_run_transfer/test_s3.py b/tests/test_integration/test_run_transfer/test_s3.py
index 0e8c8893..514bd338 100644
--- a/tests/test_integration/test_run_transfer/test_s3.py
+++ b/tests/test_integration/test_run_transfer/test_s3.py
@@ -381,3 +381,64 @@ async def test_run_transfer_postgres_to_s3_with_incremental_strategy(
     df_with_increment = reader.run()
     df_with_increment, init_df = cast_dataframe_types(df_with_increment, init_df)
     assert df_with_increment.sort("id").collect() == init_df.sort("id").collect()
+
+
+@pytest.mark.parametrize(
+    [
+        "target_file_format",
+        "file_format_flavor",
+        "strategy",
+        "transformations",
+        "expected_extension",
+    ],
+    [
+        pytest.param(
+            ("parquet", {}),
+            "without_compression",
+            lf("full_strategy"),
+            lf("sql_transformation"),
+            "parquet",
+            id="sql_transformation",
+        ),
+    ],
+    indirect=["target_file_format", "file_format_flavor"],
+)
+async def test_run_transfer_postgres_to_s3_with_sql_transformation(
+    group_owner: MockUser,
+    init_df: DataFrame,
+    client: AsyncClient,
+    s3_file_df_connection: SparkS3,
+    s3_file_connection: S3,
+    prepare_postgres,
+    prepare_s3,
+    postgres_to_s3: Transfer,
+    target_file_format,
+    file_format_flavor: str,
+    strategy,
+    transformations,
+    expected_sql_transformation,
+    expected_extension: str,
+):
+    format_name, format = target_file_format
+    target_path = f"/target/{format_name}/{file_format_flavor}"
+    _, fill_with_data = prepare_postgres
+    fill_with_data(init_df)
+
+    init_df = expected_sql_transformation(init_df, "postgres")
+
+    await run_transfer_and_verify(client, group_owner, postgres_to_s3.id, target_auth="s3")
+
+    file_names = [file.name for file in s3_file_connection.list_dir(target_path) if file.is_file()]
+    verify_file_name_template(file_names, expected_extension)
+
+    reader = FileDFReader(
+        connection=s3_file_df_connection,
+        format=format,
+        source_path=target_path,
+        df_schema=init_df.schema,
+        options={},
+    )
+    df = reader.run()
+
+    df, init_df = cast_dataframe_types(df, init_df)
+    assert df.sort("id").collect() == init_df.sort("id").collect()
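
Note for reviewers: the sketch below reproduces the temp-view pattern this patch adds to the read() handlers, but outside the worker, against a local SparkSession. The sample schema, row values, and app name are hypothetical and only mirror the sql_transformation fixture; they are not part of the patch.

# Standalone sketch of the SQL transformation flow added in read():
# the freshly read DataFrame is registered as the "source" temp view,
# and the user-supplied query is executed against it via spark.sql().
# Sample data and schema below are made up for illustration.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("sql-transformation-sketch").getOrCreate()

df = spark.createDataFrame(
    [(i, f"row_{i}") for i in range(1, 41)],
    schema="NUMBER INT, TEXT STRING",
)

# Same shape as the sql_transformation fixture: only rows with NUMBER <= 20 survive.
sql_query = "SELECT * FROM source WHERE NUMBER <= 20"

df.createOrReplaceTempView("source")  # mirrors the handler's read()
transformed = spark.sql(sql_query)

assert transformed.count() == 20
spark.stop()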