diff --git a/et_replay/comm/commsTraceParser.py b/et_replay/comm/commsTraceParser.py index d0cba6f2..a466aa24 100644 --- a/et_replay/comm/commsTraceParser.py +++ b/et_replay/comm/commsTraceParser.py @@ -7,11 +7,12 @@ import math -from et_replay import ExecutionTrace from et_replay.comm import comms_utils from et_replay.comm.backend.base_backend import supportedP2pOps from et_replay.comm.comms_utils import commsArgs +from et_replay.execution_trace import ExecutionTrace + logger = logging.getLogger(__name__) diff --git a/et_replay/comm/comms_utils.py b/et_replay/comm/comms_utils.py index 27cc00d1..62b843dd 100644 --- a/et_replay/comm/comms_utils.py +++ b/et_replay/comm/comms_utils.py @@ -20,7 +20,7 @@ from typing import Any try: - from param_bench.et_replay.vendor_internals import ( + from et_replay.vendor_internal.fb_internal import ( initialize_collectiveArgs_internal, remove_quantization_handlers, ) diff --git a/et_replay/comm/param_profile.py b/et_replay/comm/param_profile.py index a16bdf5e..5f109c2e 100644 --- a/et_replay/comm/param_profile.py +++ b/et_replay/comm/param_profile.py @@ -8,7 +8,7 @@ import logging import time from dataclasses import dataclass -from typing import Any, Optional +from typing import Any from torch.autograd.profiler import record_function diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py index 30d0e78e..dd27ad40 100644 --- a/et_replay/tools/comm_replay.py +++ b/et_replay/tools/comm_replay.py @@ -30,7 +30,7 @@ from torch.profiler import ProfilerActivity try: - from param_bench.et_replay.vendor_internal.fb_internal import ( + from et_replay.vendor_internal.fb_internal import ( get_fb_profiler_activities, get_fb_profiler_trace_handler, ) diff --git a/et_replay/tools/et_replay.py b/et_replay/tools/et_replay.py index 60fbb30b..7dab6ba6 100644 --- a/et_replay/tools/et_replay.py +++ b/et_replay/tools/et_replay.py @@ -185,7 +185,7 @@ def __init__(self): self.profile_step_label = "ProfilerStep#" try: - from param_bench.et_replay.vendor_internal.fb_internal import ( + from et_replay.vendor_internal.fb_internal import ( add_internal_parallel_nodes_parents, ) except ImportError: @@ -271,9 +271,7 @@ def initBench(self): # Input et trace should be explicitly specified after --input. if "://" in self.args.input: try: - from param_bench.et_replay.vendor_internal.fb_internal import ( - read_remote_trace, - ) + from et_replay.vendor_internal.fb_internal import read_remote_trace except ImportError: logger.info("FB internals not present") exit(1) @@ -300,7 +298,7 @@ def initBench(self): # Different processes should read different traces based on global_rank_id. if "://" in self.args.trace_path: try: - from param_bench.et_replay.vendor_internal.fb_internal import ( + from et_replay.vendor_internal.fb_internal import ( read_remote_skip_node_file, read_remote_trace, ) @@ -1848,9 +1846,7 @@ def run_iter(iter): end_time = datetime.now() try: - from param_bench.et_replay.vendor_internal.fb_internal import ( - generate_query_url, - ) + from et_replay.vendor_internal.fb_internal import generate_query_url except ImportError: logger.info("FB internals not present") else: diff --git a/et_replay/utils.py b/et_replay/utils.py index a3069b97..7821d6ae 100644 --- a/et_replay/utils.py +++ b/et_replay/utils.py @@ -4,9 +4,9 @@ import logging import os import uuid -from typing import Any, Dict +from typing import Any -from et_replay import ExecutionTrace +from et_replay.execution_trace import ExecutionTrace def get_tmp_trace_filename() -> str: diff --git a/train/comms/pt/commsTraceParser.py b/train/comms/pt/commsTraceParser.py index a0471250..2b4935d9 100644 --- a/train/comms/pt/commsTraceParser.py +++ b/train/comms/pt/commsTraceParser.py @@ -3,7 +3,7 @@ import json -from et_replay import ExecutionTrace +from et_replay.execution_trace import ExecutionTrace from param_bench.train.comms.pt import comms_utils from param_bench.train.comms.pt.comms_utils import commsArgs