Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions xprof/convert/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1741,15 +1741,33 @@ cc_library(
],
)

# Header-only target (no srcs) that exports storage_client_interface.h.
# Its only dependencies are the absl status / statusor libraries.
cc_library(
name = "storage_client_interface",
hdrs = ["storage_client_interface.h"],
deps = [
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
)

cc_library(
name = "file_utils",
srcs = ["file_utils.cc"],
hdrs = ["file_utils.h"],
srcs = [
"file_utils.cc",
"file_utils_gcs.cc",
],
hdrs = [
"file_utils.h",
"file_utils_internal.h",
],
deps = [
":storage_client_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud:google_cloud_cpp_common",
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/storage:google_cloud_cpp_storage",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
Expand All @@ -1764,10 +1782,10 @@ cc_test(
srcs = ["file_utils_test.cc"],
deps = [
":file_utils",
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud:google_cloud_cpp_common",
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/storage:google_cloud_cpp_storage",
"@com_github_googlecloudplatform_google_cloud_cpp//google/cloud/storage:google_cloud_cpp_storage_testing",
":storage_client_interface",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/profiler/protobuf:xplane_proto_cc",
"@xla//xla/tsl/lib/core:status_test_util",
Expand Down Expand Up @@ -1888,20 +1906,6 @@ cc_library(
],
)

# Test target built from dcn_utils_test.cc; links against :dcn_utils, gtest_main
# and the xplane builder/visitor/schema utilities listed below.
cc_test(
name = "dcn_utils_test",
srcs = ["dcn_utils_test.cc"],
deps = [
":dcn_utils",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@xla//xla/tsl/profiler/utils:tf_xplane_visitor",
"@xla//xla/tsl/profiler/utils:xplane_builder",
"@xla//xla/tsl/profiler/utils:xplane_schema",
"@xla//xla/tsl/profiler/utils:xplane_visitor",
],
)

cc_library(
name = "dcn_analysis",
srcs = ["dcn_analysis.cc"],
Expand Down
156 changes: 39 additions & 117 deletions xprof/convert/file_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,78 +15,52 @@ limitations under the License.

#include "xprof/convert/file_utils.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>

#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "google/cloud/options.h" // from @com_github_googlecloudplatform_google_cloud_cpp
#include "google/cloud/storage/client.h" // from @com_github_googlecloudplatform_google_cloud_cpp
#include "google/cloud/storage/download_options.h" // from @com_github_googlecloudplatform_google_cloud_cpp
#include "google/cloud/storage/options.h" // from @com_github_googlecloudplatform_google_cloud_cpp
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/threadpool.h"
#include "tsl/platform/protobuf.h"
#include "xprof/convert/file_utils_internal.h"
#include "xprof/convert/storage_client_interface.h"

namespace xprof {

namespace {

namespace gcs = ::google::cloud::storage;

// Number of concurrent threads for downloading.
constexpr int kNumThreads = 32;
// Minimum chunk size to justify parallelization (8 MB).
constexpr int64_t kMinChunkSize = 8 * 1024 * 1024;
// Maximum size for a proto, roughly 2GB.
constexpr int64_t kMaxProtoSize = 2LL * 1024 * 1024 * 1024;

// Returns the shared "gcs_read" thread pool. The pool is created with
// kNumThreads workers on the first call; this function never deletes it.
tsl::thread::ThreadPool* GetGcsThreadPool() {
  static tsl::thread::ThreadPool* const gcs_read_pool = [] {
    return new tsl::thread::ThreadPool(tsl::Env::Default(), "gcs_read",
                                       kNumThreads);
  }();
  return gcs_read_pool;
}

// Returns the shared GCS client, built on the first call with a connection
// pool sized to kNumThreads and a 1 MiB download buffer. The client is
// heap-allocated and this function never deletes it.
gcs::Client& GetGcsClient() {
  static gcs::Client* const shared_client = new gcs::Client(
      google::cloud::Options{}
          .set<gcs::ConnectionPoolSizeOption>(kNumThreads)
          .set<gcs::DownloadBufferSizeOption>(1024 * 1024));
  return *shared_client;
}

} // namespace

namespace internal {

absl::Status ParseGcsPath(absl::string_view fname, std::string* bucket,
std::string* object) {
absl::string_view path = fname;
const std::string gcs_prefix = "gs://";
const std::string bigstore_prefix = "/bigstore/";
constexpr absl::string_view kGcsPrefix = "gs://";
constexpr absl::string_view kBigstorePrefix = "/bigstore/";

if (absl::StartsWith(path, gcs_prefix)) {
path = absl::StripPrefix(path, gcs_prefix);
} else if (absl::StartsWith(path, bigstore_prefix)) {
path = absl::StripPrefix(path, bigstore_prefix);
if (absl::StartsWith(path, kGcsPrefix)) {
path = absl::StripPrefix(path, kGcsPrefix);
} else if (absl::StartsWith(path, kBigstorePrefix)) {
path = absl::StripPrefix(path, kBigstorePrefix);
} else {
return absl::InvalidArgumentError(absl::StrCat(
"GCS path must start with 'gs://' or '/bigstore/': ", fname));
}

size_t slash_pos = path.find('/');
const size_t slash_pos = path.find('/');
if (slash_pos == absl::string_view::npos || slash_pos == 0) {
return absl::InvalidArgumentError(
absl::StrCat("GCS path doesn't contain a bucket name: ", fname));
Expand All @@ -100,123 +74,70 @@ absl::Status ParseGcsPath(absl::string_view fname, std::string* bucket,
return absl::OkStatus();
}

} // namespace internal
namespace {

// Downloads `object` from `bucket` into `contents` using ranged reads that run
// in parallel on the thread pool returned by GetGcsThreadPool().
//
// `contents` is resized to `total_size`, and every chunk is read directly into
// its own [start, end) slice of that buffer, so no two workers write the same
// bytes. A chunk is at least kMinChunkSize bytes, which caps the number of
// chunks at kNumThreads.
//
// Returns OK when every chunk was read in full. Otherwise returns the first
// error recorded by any worker:
//   - InternalError  when a ranged read could not be opened,
//   - DataLossError  when the stream reports an error while reading, or when
//                    it delivers fewer bytes than the requested range.
// After an error has been recorded, workers return before starting another
// chunk.
absl::Status DownloadConcurrently(gcs::Client& client,
                                  const std::string& bucket,
                                  const std::string& object, int64_t total_size,
                                  std::string& contents) {
  contents.resize(total_size);

  const int64_t chunk_size =
      std::max(kMinChunkSize, (total_size + kNumThreads - 1) / kNumThreads);
  // Kept as int64_t: ParallelFor takes a 64-bit total, so there is no need to
  // narrow the chunk count.
  const int64_t num_chunks = (total_size + chunk_size - 1) / chunk_size;

  absl::Mutex mu;
  absl::Status status = absl::OkStatus();

  GetGcsThreadPool()->ParallelFor(
      num_chunks, tsl::thread::ThreadPool::SchedulingParams::Fixed(1),
      [&client, &bucket, &object, chunk_size, total_size, &contents, &mu,
       &status](int64_t i, int64_t end_chunk) {
        for (int64_t chunk_idx = i; chunk_idx < end_chunk; ++chunk_idx) {
          {
            // Stop early if another worker has already failed.
            absl::MutexLock lock(mu);
            if (!status.ok()) return;
          }
          const int64_t start = chunk_idx * chunk_size;
          const int64_t end = std::min(start + chunk_size, total_size);
          const int64_t expected_bytes = end - start;

          auto reader =
              client.ReadObject(bucket, object, gcs::ReadRange(start, end));
          if (!reader) {
            absl::MutexLock lock(mu);
            status.Update(absl::InternalError(absl::StrCat(
                "Failed to read range: ", reader.status().message())));
            return;
          }
          reader.read(&contents[start], expected_bytes);
          if (!reader.status().ok()) {
            absl::MutexLock lock(mu);
            status.Update(absl::DataLossError(absl::StrCat(
                "Failed to read GCS range data: ", reader.status().message())));
            return;
          }
          // A stream can end early without reporting an error status. Without
          // this check the rest of the slice would keep its zero fill and the
          // caller would receive OK for a truncated download.
          const int64_t received_bytes = reader.gcount();
          if (received_bytes != expected_bytes) {
            absl::MutexLock lock(mu);
            status.Update(absl::DataLossError(absl::StrCat(
                "Short read from GCS at offset ", start, ": expected ",
                expected_bytes, " bytes, got ", received_bytes)));
            return;
          }
        }
      });
  return status;
}

} // namespace
namespace internal {

absl::Status ReadBinaryProtoWithClient(gcs::Client& client,
absl::Status ReadBinaryProtoWithClient(StorageClientInterface& client,
const std::string& fname,
tsl::protobuf::MessageLite* proto) {
std::string bucket, object;
std::string bucket;
std::string object;
TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object));

// Get object metadata to find the size.
auto metadata = client.GetObjectMetadata(bucket, object);
if (!metadata) {
// Get object size.
const absl::StatusOr<std::uint64_t> size_or =
client.GetObjectSize(bucket, object);
if (!size_or.ok()) {
return absl::NotFoundError(absl::StrCat("Failed to get GCS metadata: ",
metadata.status().message()));
size_or.status().message()));
}

int64_t total_size = static_cast<int64_t>(metadata->size());
const std::uint64_t total_size = *size_or;
if (total_size == 0) {
proto->Clear();
return absl::OkStatus();
}

if (total_size > kMaxProtoSize) {
if (total_size > static_cast<std::uint64_t>(kMaxProtoSize)) {
return absl::FailedPreconditionError(
absl::StrCat("File too large for a proto: ", total_size));
}

std::string contents;
absl::Time start_download = absl::Now();
contents.resize(total_size);
const absl::Time start_download = absl::Now();
TF_RETURN_IF_ERROR(
DownloadConcurrently(client, bucket, object, total_size, contents));
absl::Time end_download = absl::Now();
client.ReadObject(bucket, object, 0, total_size, &contents[0]));
const absl::Time end_download = absl::Now();
VLOG(1) << "Download from GCS took: " << end_download - start_download;

absl::Time start_parse = absl::Now();
const absl::Time start_parse = absl::Now();
if (!proto->ParseFromString(contents)) {
return absl::DataLossError(
absl::StrCat("Can't parse ", fname, " as binary proto"));
}
absl::Time end_parse = absl::Now();
const absl::Time end_parse = absl::Now();
VLOG(1) << "Protobuf parsing took: " << end_parse - start_parse;

return absl::OkStatus();
}

absl::Status WriteBinaryProtoWithClient(
gcs::Client& client, const std::string& fname,
StorageClientInterface& client, const std::string& fname,
const tsl::protobuf::MessageLite& proto) {
std::string bucket, object;
std::string bucket;
std::string object;
TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object));

std::string contents;
absl::Time start_serialize = absl::Now();
const absl::Time start_serialize = absl::Now();
if (!proto.SerializeToString(&contents)) {
return absl::InternalError(
absl::StrCat("Failed to serialize proto to string for ", fname));
}
absl::Time end_serialize = absl::Now();
const absl::Time end_serialize = absl::Now();
LOG(INFO) << "Proto serialization took: " << end_serialize - start_serialize;

absl::Time start_upload = absl::Now();
auto stream = client.WriteObject(bucket, object);
stream << contents;
stream.Close();
absl::Time end_upload = absl::Now();
if (!stream) {
return absl::InternalError(absl::StrCat(
"Failed to write to GCS: ", stream.metadata().status().message()));
}
const absl::Time start_upload = absl::Now();
TF_RETURN_IF_ERROR(client.WriteObject(bucket, object, contents));
const absl::Time end_upload = absl::Now();
LOG(INFO) << "Upload to GCS took: " << end_upload - start_upload;
return absl::OkStatus();
}
Expand All @@ -227,7 +148,8 @@ absl::Status ReadBinaryProto(const std::string& fname,
tsl::protobuf::MessageLite* proto) {
if (absl::StartsWith(fname, "gs://") ||
absl::StartsWith(fname, "/bigstore/")) {
return internal::ReadBinaryProtoWithClient(GetGcsClient(), fname, proto);
return internal::ReadBinaryProtoWithClient(internal::GetDefaultGcsClient(),
fname, proto);
}

return tsl::ReadBinaryProto(tsl::Env::Default(), fname, proto);
Expand All @@ -241,8 +163,8 @@ absl::Status WriteBinaryProto(const std::string& fname,
if (absl::StartsWith(fname, "/bigstore/")) {
gcs_path = absl::StrCat("gs://", absl::StripPrefix(fname, "/bigstore/"));
}
return internal::WriteBinaryProtoWithClient(GetGcsClient(), gcs_path,
proto);
return internal::WriteBinaryProtoWithClient(internal::GetDefaultGcsClient(),
gcs_path, proto);
}

return tsl::WriteBinaryProto(tsl::Env::Default(), fname, proto);
Expand Down
27 changes: 2 additions & 25 deletions xprof/convert/file_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,41 +19,18 @@ limitations under the License.
#include <string>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "google/cloud/storage/client.h" // from @com_github_googlecloudplatform_google_cloud_cpp
#include "tsl/platform/protobuf.h"

namespace xprof {

// Reads a binary proto from a GCS path (gs://bucket/object) using the
// google-cloud-cpp storage API.
// Reads a binary proto from a file path. Supports local and GCS paths.
absl::Status ReadBinaryProto(const std::string& fname,
tsl::protobuf::MessageLite* proto);

// Writes a binary proto to a GCS path (gs://bucket/object) using the
// google-cloud-cpp storage API.
// Falls back to tsl::WriteBinaryProto for non-GCS paths.
// Writes a binary proto to a file path. Supports local and GCS paths.
absl::Status WriteBinaryProto(const std::string& fname,
const tsl::protobuf::MessageLite& proto);

// Implementation details exposed so that tests can supply their own GCS
// client. Not intended for use outside of file_utils and its tests.
namespace internal {

// Parses a GCS path. Supports gs:// and /bigstore/ prefixes.
// Returns InvalidArgumentError when `fname` has neither prefix or when no
// bucket name precedes the first '/'.
absl::Status ParseGcsPath(absl::string_view fname, std::string* bucket,
std::string* object);

// Internal implementation that takes a GCS client, used for testing.
// Reads the object named by `fname` through `client` and parses it into
// `proto` as a binary proto.
absl::Status ReadBinaryProtoWithClient(google::cloud::storage::Client& client,
const std::string& fname,
tsl::protobuf::MessageLite* proto);

// Internal implementation that takes a GCS client, used for testing.
// Serializes `proto` and writes it through `client` to the object named by
// `fname`.
absl::Status WriteBinaryProtoWithClient(
google::cloud::storage::Client& client, const std::string& fname,
const tsl::protobuf::MessageLite& proto);

} // namespace internal

} // namespace xprof

#endif // THIRD_PARTY_XPROF_CONVERT_FILE_UTILS_H_
Loading