From d79c5ec981d7856b18c55cd276c691bfb145a27c Mon Sep 17 00:00:00 2001 From: Sai Ganesh Muthuraman Date: Thu, 26 Feb 2026 01:08:08 -0800 Subject: [PATCH] Decouple third_party/xprof/convert:repository from heavy google-cloud-cpp dependencies. PiperOrigin-RevId: 875576876 --- xprof/convert/BUILD | 42 +++--- xprof/convert/file_utils.cc | 156 ++++++---------------- xprof/convert/file_utils.h | 27 +--- xprof/convert/file_utils_gcs.cc | 157 +++++++++++++++++++++++ xprof/convert/file_utils_internal.h | 49 +++++++ xprof/convert/file_utils_test.cc | 116 ++++++----------- xprof/convert/storage_client_interface.h | 53 ++++++++ 7 files changed, 365 insertions(+), 235 deletions(-) create mode 100644 xprof/convert/file_utils_gcs.cc create mode 100644 xprof/convert/file_utils_internal.h create mode 100644 xprof/convert/storage_client_interface.h diff --git a/xprof/convert/BUILD b/xprof/convert/BUILD index 8d476c88..8bd4760d 100644 --- a/xprof/convert/BUILD +++ b/xprof/convert/BUILD @@ -1741,15 +1741,33 @@ cc_library( ], ) +cc_library( + name = "storage_client_interface", + hdrs = ["storage_client_interface.h"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + cc_library( name = "file_utils", - srcs = ["file_utils.cc"], - hdrs = ["file_utils.h"], + srcs = [ + "file_utils.cc", + "file_utils_gcs.cc", + ], + hdrs = [ + "file_utils.h", + "file_utils_internal.h", + ], deps = [ + ":storage_client_interface", "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud:google_cloud_cpp_common", "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/storage:google_cloud_cpp_storage", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", @@ -1764,10 +1782,10 @@ cc_test( srcs = ["file_utils_test.cc"], deps = [ ":file_utils", - "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud:google_cloud_cpp_common", - "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/storage:google_cloud_cpp_storage", - "@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/storage:google_cloud_cpp_storage_testing", + ":storage_client_interface", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@tsl//tsl/profiler/protobuf:xplane_proto_cc", "@xla//xla/tsl/lib/core:status_test_util", @@ -1888,20 +1906,6 @@ cc_library( ], ) -cc_test( - name = "dcn_utils_test", - srcs = ["dcn_utils_test.cc"], - deps = [ - ":dcn_utils", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@xla//xla/tsl/profiler/utils:tf_xplane_visitor", - "@xla//xla/tsl/profiler/utils:xplane_builder", - "@xla//xla/tsl/profiler/utils:xplane_schema", - "@xla//xla/tsl/profiler/utils:xplane_visitor", - ], -) - cc_library( name = "dcn_analysis", srcs = ["dcn_analysis.cc"], diff --git a/xprof/convert/file_utils.cc b/xprof/convert/file_utils.cc index b74cb219..1beaa64e 100644 --- a/xprof/convert/file_utils.cc +++ b/xprof/convert/file_utils.cc @@ -15,78 +15,52 @@ limitations under the License. #include "xprof/convert/file_utils.h" -#include #include #include #include -#include #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" -#include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "absl/time/time.h" -#include "google/cloud/options.h" // from @com_github_googlecloudplatform_google_cloud_cpp -#include "google/cloud/storage/client.h" // from @com_github_googlecloudplatform_google_cloud_cpp -#include "google/cloud/storage/download_options.h" // from @com_github_googlecloudplatform_google_cloud_cpp -#include "google/cloud/storage/options.h" // from @com_github_googlecloudplatform_google_cloud_cpp #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" -#include "xla/tsl/platform/threadpool.h" #include "tsl/platform/protobuf.h" +#include "xprof/convert/file_utils_internal.h" +#include "xprof/convert/storage_client_interface.h" namespace xprof { namespace { -namespace gcs = ::google::cloud::storage; - -// Number of concurrent threads for downloading. -constexpr int kNumThreads = 32; -// Minimum chunk size to justify parallelization (8 MB). -constexpr int64_t kMinChunkSize = 8 * 1024 * 1024; // Maximum size for a proto, roughly 2GB. constexpr int64_t kMaxProtoSize = 2LL * 1024 * 1024 * 1024; -tsl::thread::ThreadPool* GetGcsThreadPool() { - static tsl::thread::ThreadPool* pool = - new tsl::thread::ThreadPool(tsl::Env::Default(), "gcs_read", kNumThreads); - return pool; -} - -gcs::Client& GetGcsClient() { - static auto* client = []() { - auto options = google::cloud::Options{} - .set(kNumThreads) - .set(1024 * 1024); - return new gcs::Client(std::move(options)); - }(); - return *client; -} - } // namespace + namespace internal { absl::Status ParseGcsPath(absl::string_view fname, std::string* bucket, std::string* object) { absl::string_view path = fname; - const std::string gcs_prefix = "gs://"; - const std::string bigstore_prefix = "/bigstore/"; + constexpr absl::string_view kGcsPrefix = "gs://"; + constexpr absl::string_view kBigstorePrefix = "/bigstore/"; - if (absl::StartsWith(path, gcs_prefix)) { - path = absl::StripPrefix(path, gcs_prefix); - } else if (absl::StartsWith(path, bigstore_prefix)) { - path = absl::StripPrefix(path, bigstore_prefix); + if (absl::StartsWith(path, kGcsPrefix)) { + path = absl::StripPrefix(path, kGcsPrefix); + } else if (absl::StartsWith(path, kBigstorePrefix)) { + path = absl::StripPrefix(path, kBigstorePrefix); } else { return absl::InvalidArgumentError(absl::StrCat( "GCS path must start with 'gs://' or '/bigstore/': ", fname)); } - size_t slash_pos = path.find('/'); + const size_t slash_pos = path.find('/'); if (slash_pos == absl::string_view::npos || slash_pos == 0) { return absl::InvalidArgumentError( absl::StrCat("GCS path doesn't contain a bucket name: ", fname)); @@ -100,123 +74,70 @@ absl::Status ParseGcsPath(absl::string_view fname, std::string* bucket, return absl::OkStatus(); } -} // namespace internal -namespace { - -absl::Status DownloadConcurrently(gcs::Client& client, - const std::string& bucket, - const std::string& object, int64_t total_size, - std::string& contents) { - contents.resize(total_size); - - int64_t chunk_size = - std::max(kMinChunkSize, (total_size + kNumThreads - 1) / kNumThreads); - int num_chunks = (total_size + chunk_size - 1) / chunk_size; - - absl::Mutex mu; - absl::Status status = absl::OkStatus(); - - GetGcsThreadPool()->ParallelFor( - num_chunks, tsl::thread::ThreadPool::SchedulingParams::Fixed(1), - [&client, &bucket, &object, chunk_size, total_size, &contents, &mu, - &status](int64_t i, int64_t end_chunk) { - for (int64_t chunk_idx = i; chunk_idx < end_chunk; ++chunk_idx) { - { - absl::MutexLock lock(mu); - if (!status.ok()) return; - } - int64_t start = chunk_idx * chunk_size; - int64_t end = std::min(start + chunk_size, total_size); - - auto reader = - client.ReadObject(bucket, object, gcs::ReadRange(start, end)); - if (!reader) { - absl::MutexLock lock(mu); - status.Update(absl::InternalError(absl::StrCat( - "Failed to read range: ", reader.status().message()))); - return; - } - reader.read(&contents[start], end - start); - if (!reader.status().ok()) { - absl::MutexLock lock(mu); - status.Update(absl::DataLossError(absl::StrCat( - "Failed to read GCS range data: ", reader.status().message()))); - return; - } - } - }); - return status; -} - -} // namespace -namespace internal { - -absl::Status ReadBinaryProtoWithClient(gcs::Client& client, +absl::Status ReadBinaryProtoWithClient(StorageClientInterface& client, const std::string& fname, tsl::protobuf::MessageLite* proto) { - std::string bucket, object; + std::string bucket; + std::string object; TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object)); - // Get object metadata to find the size. - auto metadata = client.GetObjectMetadata(bucket, object); - if (!metadata) { + // Get object size. + const absl::StatusOr size_or = + client.GetObjectSize(bucket, object); + if (!size_or.ok()) { return absl::NotFoundError(absl::StrCat("Failed to get GCS metadata: ", - metadata.status().message())); + size_or.status().message())); } - int64_t total_size = static_cast(metadata->size()); + const std::uint64_t total_size = *size_or; if (total_size == 0) { proto->Clear(); return absl::OkStatus(); } - if (total_size > kMaxProtoSize) { + if (total_size > static_cast(kMaxProtoSize)) { return absl::FailedPreconditionError( absl::StrCat("File too large for a proto: ", total_size)); } std::string contents; - absl::Time start_download = absl::Now(); + contents.resize(total_size); + const absl::Time start_download = absl::Now(); TF_RETURN_IF_ERROR( - DownloadConcurrently(client, bucket, object, total_size, contents)); - absl::Time end_download = absl::Now(); + client.ReadObject(bucket, object, 0, total_size, &contents[0])); + const absl::Time end_download = absl::Now(); VLOG(1) << "Download from GCS took: " << end_download - start_download; - absl::Time start_parse = absl::Now(); + const absl::Time start_parse = absl::Now(); if (!proto->ParseFromString(contents)) { return absl::DataLossError( absl::StrCat("Can't parse ", fname, " as binary proto")); } - absl::Time end_parse = absl::Now(); + const absl::Time end_parse = absl::Now(); VLOG(1) << "Protobuf parsing took: " << end_parse - start_parse; return absl::OkStatus(); } absl::Status WriteBinaryProtoWithClient( - gcs::Client& client, const std::string& fname, + StorageClientInterface& client, const std::string& fname, const tsl::protobuf::MessageLite& proto) { - std::string bucket, object; + std::string bucket; + std::string object; TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object)); std::string contents; - absl::Time start_serialize = absl::Now(); + const absl::Time start_serialize = absl::Now(); if (!proto.SerializeToString(&contents)) { return absl::InternalError( absl::StrCat("Failed to serialize proto to string for ", fname)); } - absl::Time end_serialize = absl::Now(); + const absl::Time end_serialize = absl::Now(); LOG(INFO) << "Proto serialization took: " << end_serialize - start_serialize; - absl::Time start_upload = absl::Now(); - auto stream = client.WriteObject(bucket, object); - stream << contents; - stream.Close(); - absl::Time end_upload = absl::Now(); - if (!stream) { - return absl::InternalError(absl::StrCat( - "Failed to write to GCS: ", stream.metadata().status().message())); - } + const absl::Time start_upload = absl::Now(); + TF_RETURN_IF_ERROR(client.WriteObject(bucket, object, contents)); + const absl::Time end_upload = absl::Now(); LOG(INFO) << "Upload to GCS took: " << end_upload - start_upload; return absl::OkStatus(); } @@ -227,7 +148,8 @@ absl::Status ReadBinaryProto(const std::string& fname, tsl::protobuf::MessageLite* proto) { if (absl::StartsWith(fname, "gs://") || absl::StartsWith(fname, "/bigstore/")) { - return internal::ReadBinaryProtoWithClient(GetGcsClient(), fname, proto); + return internal::ReadBinaryProtoWithClient(internal::GetDefaultGcsClient(), + fname, proto); } return tsl::ReadBinaryProto(tsl::Env::Default(), fname, proto); @@ -241,8 +163,8 @@ absl::Status WriteBinaryProto(const std::string& fname, if (absl::StartsWith(fname, "/bigstore/")) { gcs_path = absl::StrCat("gs://", absl::StripPrefix(fname, "/bigstore/")); } - return internal::WriteBinaryProtoWithClient(GetGcsClient(), gcs_path, - proto); + return internal::WriteBinaryProtoWithClient(internal::GetDefaultGcsClient(), + gcs_path, proto); } return tsl::WriteBinaryProto(tsl::Env::Default(), fname, proto); diff --git a/xprof/convert/file_utils.h b/xprof/convert/file_utils.h index 1be17b82..e9e8424d 100644 --- a/xprof/convert/file_utils.h +++ b/xprof/convert/file_utils.h @@ -19,41 +19,18 @@ limitations under the License. #include #include "absl/status/status.h" -#include "absl/strings/string_view.h" -#include "google/cloud/storage/client.h" // from @com_github_googlecloudplatform_google_cloud_cpp #include "tsl/platform/protobuf.h" namespace xprof { -// Reads a binary proto from a GCS path (gs://bucket/object) using the -// google-cloud-cpp storage API. +// Reads a binary proto from a file path. Supports local and GCS paths. absl::Status ReadBinaryProto(const std::string& fname, tsl::protobuf::MessageLite* proto); -// Writes a binary proto to a GCS path (gs://bucket/object) using the -// google-cloud-cpp storage API. -// Falls back to tsl::WriteBinaryProto for non-GCS paths. +// Writes a binary proto to a file path. Supports local and GCS paths. absl::Status WriteBinaryProto(const std::string& fname, const tsl::protobuf::MessageLite& proto); -namespace internal { - -// Parses a GCS path. Supports gs:// and /bigstore/ prefixes. -absl::Status ParseGcsPath(absl::string_view fname, std::string* bucket, - std::string* object); - -// Internal implementation that takes a GCS client, used for testing. -absl::Status ReadBinaryProtoWithClient(google::cloud::storage::Client& client, - const std::string& fname, - tsl::protobuf::MessageLite* proto); - -// Internal implementation that takes a GCS client, used for testing. -absl::Status WriteBinaryProtoWithClient( - google::cloud::storage::Client& client, const std::string& fname, - const tsl::protobuf::MessageLite& proto); - -} // namespace internal - } // namespace xprof #endif // THIRD_PARTY_XPROF_CONVERT_FILE_UTILS_H_ diff --git a/xprof/convert/file_utils_gcs.cc b/xprof/convert/file_utils_gcs.cc new file mode 100644 index 00000000..a3341101 --- /dev/null +++ b/xprof/convert/file_utils_gcs.cc @@ -0,0 +1,157 @@ +/* Copyright 2026 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" +#include "google/cloud/options.h" // from @com_github_googlecloudplatform_google_cloud_cpp +#include "google/cloud/status.h" // from @com_github_googlecloudplatform_google_cloud_cpp +#include "google/cloud/status_or.h" // from @com_github_googlecloudplatform_google_cloud_cpp +#include "google/cloud/storage/client.h" // from @com_github_googlecloudplatform_google_cloud_cpp +#include "google/cloud/storage/download_options.h" // from @com_github_googlecloudplatform_google_cloud_cpp +#include "google/cloud/storage/object_metadata.h" // from @com_github_googlecloudplatform_google_cloud_cpp +#include "google/cloud/storage/object_read_stream.h" // from @com_github_googlecloudplatform_google_cloud_cpp +#include "google/cloud/storage/object_write_stream.h" // from @com_github_googlecloudplatform_google_cloud_cpp +#include "google/cloud/storage/options.h" // from @com_github_googlecloudplatform_google_cloud_cpp +#include "xla/tsl/platform/env.h" +#include "xla/tsl/platform/threadpool.h" +#include "xprof/convert/storage_client_interface.h" + +namespace xprof { +namespace { + +namespace gcs = ::google::cloud::storage; + +// Number of concurrent threads for downloading. +constexpr int kNumThreads = 32; +// Minimum chunk size to justify parallelization (8 MB). +constexpr int64_t kMinChunkSize = 8 * 1024 * 1024; +// Download buffer size for GCS client (1 MB). +constexpr int kDownloadBufferSize = 1024 * 1024; + +tsl::thread::ThreadPool& GetGcsThreadPool() { + static absl::NoDestructor pool( + tsl::Env::Default(), "gcs_read", kNumThreads); + return *pool; +} + +// Implementation of StorageClientInterface for GCS. +class GcsStorageClient : public internal::StorageClientInterface { + public: + explicit GcsStorageClient(gcs::Client client) : client_(std::move(client)) {} + + // Returns the size of the object in GCS. + absl::StatusOr GetObjectSize( + const std::string& bucket, const std::string& object) override { + const google::cloud::StatusOr metadata = + client_.GetObjectMetadata(bucket, object); + if (!metadata) { + if (metadata.status().code() == google::cloud::StatusCode::kNotFound) { + return absl::NotFoundError(metadata.status().message()); + } + return absl::InternalError(metadata.status().message()); + } + return metadata->size(); + } + + // Reads the object from GCS in parallel chunks. + absl::Status ReadObject(const std::string& bucket, const std::string& object, + std::uint64_t start, std::uint64_t end, + char* buffer) override { + const std::uint64_t total_size = end - start; + const std::uint64_t chunk_size = std::max( + kMinChunkSize, (total_size + kNumThreads - 1) / kNumThreads); + const int num_chunks = static_cast( + (total_size + chunk_size - 1) / chunk_size); + + absl::Mutex mu; + absl::Status status = absl::OkStatus(); + + GetGcsThreadPool().ParallelFor( + num_chunks, tsl::thread::ThreadPool::SchedulingParams::Fixed(1), + [this, &bucket, &object, start, chunk_size, total_size, buffer, &mu, + &status](int64_t i, int64_t end_chunk) { + for (int64_t chunk_idx = i; chunk_idx < end_chunk; ++chunk_idx) { + { + absl::MutexLock lock(&mu); + if (!status.ok()) return; + } + const std::uint64_t chunk_start = start + chunk_idx * chunk_size; + const std::uint64_t chunk_end = + std::min(chunk_start + chunk_size, start + total_size); + + gcs::ObjectReadStream reader = client_.ReadObject( + bucket, object, gcs::ReadRange(chunk_start, chunk_end)); + if (!reader) { + absl::MutexLock lock(&mu); + status.Update(absl::InternalError(absl::StrCat( + "Failed to read range: ", reader.status().message()))); + return; + } + reader.read(buffer + (chunk_start - start), + chunk_end - chunk_start); + if (!reader.status().ok()) { + absl::MutexLock lock(&mu); + status.Update(absl::DataLossError(absl::StrCat( + "Failed to read GCS range data: ", + reader.status().message()))); + return; + } + } + }); + return status; + } + + // Writes the contents to the object in GCS. + absl::Status WriteObject(const std::string& bucket, const std::string& object, + const std::string& contents) override { + gcs::ObjectWriteStream stream = client_.WriteObject(bucket, object); + stream << contents; + stream.Close(); + if (!stream) { + return absl::InternalError(absl::StrCat( + "Failed to write to GCS: ", stream.metadata().status().message())); + } + return absl::OkStatus(); + } + + private: + gcs::Client client_; +}; + +} // namespace + +namespace internal { + +StorageClientInterface& GetDefaultGcsClient() { + static absl::NoDestructor client([] { + google::cloud::Options options = + google::cloud::Options{} + .set(kNumThreads) + .set(kDownloadBufferSize); + return GcsStorageClient(gcs::Client(std::move(options))); + }()); + return *client; +} + +} // namespace internal +} // namespace xprof diff --git a/xprof/convert/file_utils_internal.h b/xprof/convert/file_utils_internal.h new file mode 100644 index 00000000..58f4f2b1 --- /dev/null +++ b/xprof/convert/file_utils_internal.h @@ -0,0 +1,49 @@ +/* Copyright 2026 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_XPROF_CONVERT_FILE_UTILS_INTERNAL_H_ +#define THIRD_PARTY_XPROF_CONVERT_FILE_UTILS_INTERNAL_H_ + +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tsl/platform/protobuf.h" +#include "xprof/convert/storage_client_interface.h" + +namespace xprof { +namespace internal { + +// Parses a GCS path. Supports gs:// and /bigstore/ prefixes. +absl::Status ParseGcsPath(absl::string_view fname, std::string* bucket, + std::string* object); + +// Internal implementation that takes a storage client interface. +absl::Status ReadBinaryProtoWithClient(StorageClientInterface& client, + const std::string& fname, + tsl::protobuf::MessageLite* proto); + +// Internal implementation that takes a storage client interface. +absl::Status WriteBinaryProtoWithClient( + StorageClientInterface& client, const std::string& fname, + const tsl::protobuf::MessageLite& proto); + +// Returns the default GCS client implementation. Defined in file_utils_gcs.cc. +StorageClientInterface& GetDefaultGcsClient(); + +} // namespace internal +} // namespace xprof + +#endif // THIRD_PARTY_XPROF_CONVERT_FILE_UTILS_INTERNAL_H_ diff --git a/xprof/convert/file_utils_test.cc b/xprof/convert/file_utils_test.cc index d20360b0..9722fcf7 100644 --- a/xprof/convert/file_utils_test.cc +++ b/xprof/convert/file_utils_test.cc @@ -16,23 +16,18 @@ limitations under the License. #include "xprof/convert/file_utils.h" #include -#include -#include +#include #include -#include #include "testing/base/public/gmock.h" #include "" #include "absl/status/status.h" -#include "google/cloud/status_or.h" // from @com_github_googlecloudplatform_google_cloud_cpp -#include "google/cloud/storage/client.h" // from @com_github_googlecloudplatform_google_cloud_cpp -#include "google/cloud/storage/internal/http_response.h" // from @com_github_googlecloudplatform_google_cloud_cpp -#include "google/cloud/storage/internal/object_read_source.h" // from @com_github_googlecloudplatform_google_cloud_cpp -#include "google/cloud/storage/internal/object_requests.h" // from @com_github_googlecloudplatform_google_cloud_cpp -#include "google/cloud/storage/object_metadata.h" // from @com_github_googlecloudplatform_google_cloud_cpp -#include "google/cloud/storage/testing/mock_client.h" // from @com_github_googlecloudplatform_google_cloud_cpp +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/profiler/protobuf/xplane.pb.h" +#include "xprof/convert/file_utils_internal.h" +#include "xprof/convert/storage_client_interface.h" namespace xprof { namespace { @@ -43,11 +38,25 @@ using ::testing::Return; using ::xprof::internal::ParseGcsPath; using ::xprof::internal::ReadBinaryProtoWithClient; using ::xprof::internal::WriteBinaryProtoWithClient; -namespace gcs = ::google::cloud::storage; -namespace gcs_testing = ::google::cloud::storage::testing; + +class MockStorageClient : public internal::StorageClientInterface { + public: + MOCK_METHOD(absl::StatusOr, GetObjectSize, + (const std::string& bucket, const std::string& object), + (override)); + MOCK_METHOD(absl::Status, ReadObject, + (const std::string& bucket, const std::string& object, + std::uint64_t start, std::uint64_t end, char* buffer), + (override)); + MOCK_METHOD(absl::Status, WriteObject, + (const std::string& bucket, const std::string& object, + const std::string& contents), + (override)); +}; TEST(FileUtilsTest, ParseGcsPath_GsPrefix) { - std::string bucket, object; + std::string bucket; + std::string object; TF_EXPECT_OK( ParseGcsPath("gs://my-bucket/path/to/object.hlo", &bucket, &object)); EXPECT_THAT(bucket, Eq("my-bucket")); @@ -55,7 +64,8 @@ TEST(FileUtilsTest, ParseGcsPath_GsPrefix) { } TEST(FileUtilsTest, ParseGcsPath_BigstorePrefix) { - std::string bucket, object; + std::string bucket; + std::string object; TF_EXPECT_OK( ParseGcsPath("/bigstore/my-bucket/path/to/object.hlo", &bucket, &object)); EXPECT_THAT(bucket, Eq("my-bucket")); @@ -63,86 +73,44 @@ TEST(FileUtilsTest, ParseGcsPath_BigstorePrefix) { } TEST(FileUtilsTest, ParseGcsPath_Invalid) { - std::string bucket, object; + std::string bucket; + std::string object; EXPECT_FALSE(ParseGcsPath("s3://my-bucket/object", &bucket, &object).ok()); EXPECT_FALSE(ParseGcsPath("gs://my-bucket", &bucket, &object).ok()); EXPECT_FALSE(ParseGcsPath("gs:///object", &bucket, &object).ok()); } TEST(FileUtilsTest, ReadBinaryProtoWithClient_Success) { - auto mock = std::make_shared(); - gcs::Client client = gcs_testing::UndecoratedClientFromMock(mock); - - std::string bucket = "bucket"; - std::string object = "object"; - std::string content = "XSpace content"; - - gcs::ObjectMetadata metadata; - metadata.set_size(content.size()); - - EXPECT_CALL(*mock, GetObjectMetadata(_)) - .WillOnce(Return(google::cloud::StatusOr(metadata))); - - auto mock_source = std::make_unique(); - EXPECT_CALL(*mock_source, IsOpen()).WillRepeatedly(Return(true)); - EXPECT_CALL(*mock_source, Read(_, _)) - .WillOnce([content](char* buf, std::size_t n) { - std::copy(content.begin(), content.end(), buf); - return gcs::internal::ReadSourceResult{ - content.size(), gcs::internal::HttpResponse{200, "", {}}}; - }); + MockStorageClient client; + constexpr absl::string_view kContent = "XSpace content"; - EXPECT_CALL(*mock, ReadObject(_)) - .WillOnce(Return(google::cloud::StatusOr< - std::unique_ptr>( - std::move(mock_source)))); + EXPECT_CALL(client, GetObjectSize("bucket", "object")) + .WillOnce(Return(kContent.size())); + + EXPECT_CALL(client, ReadObject("bucket", "object", 0, kContent.size(), _)) + .WillOnce([kContent](const std::string&, const std::string&, + std::uint64_t, std::uint64_t, char* buf) { + std::copy(kContent.begin(), kContent.end(), buf); + return absl::OkStatus(); + }); tensorflow::profiler::XSpace xspace; - absl::Status status = + const absl::Status status = ReadBinaryProtoWithClient(client, "gs://bucket/object", &xspace); - // The ReadBinaryProtoWithClient function first reads the data and then tries - // to parse it as a binary proto. Since "XSpace content" is not a valid - // serialized proto, the parsing will fail with kDataLoss. This expectation - // confirms that the download was successful and the code proceeded to the - // parsing stage. + // Expect kDataLoss because "XSpace content" is not a valid serialized proto. EXPECT_THAT(status.code(), Eq(absl::StatusCode::kDataLoss)); } TEST(FileUtilsTest, WriteBinaryProtoWithClient_Success) { - auto mock = std::make_shared(); - gcs::Client client = gcs_testing::UndecoratedClientFromMock(mock); - - std::string bucket = "bucket"; - std::string object = "object"; + MockStorageClient client; tensorflow::profiler::XSpace xspace; xspace.add_hostnames("test-host"); std::string expected_contents; xspace.SerializeToString(&expected_contents); - EXPECT_CALL(*mock, CreateResumableUpload(_)) - .WillOnce([bucket, - object](gcs::internal::ResumableUploadRequest const& request) { - EXPECT_EQ(request.bucket_name(), bucket); - EXPECT_EQ(request.object_name(), object); - return google::cloud::StatusOr< - gcs::internal::CreateResumableUploadResponse>( - gcs::internal::CreateResumableUploadResponse{"session-id"}); - }); - - EXPECT_CALL(*mock, UploadChunk(_)) - .WillOnce([expected_contents]( - gcs::internal::UploadChunkRequest const& request) { - std::string actual_payload; - for (auto const& b : request.payload()) { - actual_payload.append(static_cast(b.data()), b.size()); - } - EXPECT_EQ(actual_payload, expected_contents); - return google::cloud::StatusOr< - gcs::internal::QueryResumableUploadResponse>( - gcs::internal::QueryResumableUploadResponse{ - expected_contents.size(), gcs::ObjectMetadata{}}); - }); + EXPECT_CALL(client, WriteObject("bucket", "object", expected_contents)) + .WillOnce(Return(absl::OkStatus())); TF_EXPECT_OK( WriteBinaryProtoWithClient(client, "gs://bucket/object", xspace)); diff --git a/xprof/convert/storage_client_interface.h b/xprof/convert/storage_client_interface.h new file mode 100644 index 00000000..1d7bf8f7 --- /dev/null +++ b/xprof/convert/storage_client_interface.h @@ -0,0 +1,53 @@ +/* Copyright 2026 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_XPROF_CONVERT_STORAGE_CLIENT_INTERFACE_H_ +#define THIRD_PARTY_XPROF_CONVERT_STORAGE_CLIENT_INTERFACE_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +namespace xprof { +namespace internal { + +// Interface for storage operations to decouple from heavy cloud headers +// and facilitate mocking in tests. +class StorageClientInterface { + public: + virtual ~StorageClientInterface() = default; + + // Returns the size of the object in bytes. + virtual absl::StatusOr GetObjectSize( + const std::string& bucket, const std::string& object) = 0; + + // Reads a range of bytes from the object into the buffer. + virtual absl::Status ReadObject(const std::string& bucket, + const std::string& object, + std::uint64_t start, std::uint64_t end, + char* buffer) = 0; + + // Writes the entire contents to the object. + virtual absl::Status WriteObject(const std::string& bucket, + const std::string& object, + const std::string& contents) = 0; +}; + +} // namespace internal +} // namespace xprof + +#endif // THIRD_PARTY_XPROF_CONVERT_STORAGE_CLIENT_INTERFACE_H_