diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index e53acebf6..e8d1ce862 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -30,13 +30,13 @@ jobs: - name: Checkout uses: actions/checkout@v3 - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: ${{ matrix.language }} - name: Autobuild - uses: github/codeql-action/autobuild@v2 + uses: github/codeql-action/autobuild@v3 - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 analyze-cpp: name: CodeQL analyze cpp runs-on: ubuntu-latest @@ -52,12 +52,16 @@ jobs: - name: Install Dependency run: | DEBIAN_FRONTEND=noninteractive apt-get update - DEBIAN_FRONTEND=noninteractive apt-get install -y ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswresample-dev sudo + DEBIAN_FRONTEND=noninteractive apt-get install -y ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswresample-dev sudo build-essential + - name: Setup CMake + uses: lukka/get-cmake@latest + with: + cmakeVersion: '3.20.0' - name: Initialize CodeQL - uses: github/codeql-action/init@v2 + uses: github/codeql-action/init@v3 with: languages: cpp - name: Build run: make cppbuild -j - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v2 + uses: github/codeql-action/analyze@v3 diff --git a/.gitignore b/.gitignore index 5888455a8..97b51dcd0 100644 --- a/.gitignore +++ b/.gitignore @@ -151,6 +151,9 @@ cython_debug/ *.userosscache *.sln.docstates +# Build temporary files +compile_commands.json + # Build results [Dd]ebug/ [Dd]ebugPublic/ diff --git a/.gitmodules b/.gitmodules index e7d2af022..03afed789 100644 --- a/.gitmodules +++ b/.gitmodules @@ -33,3 +33,6 @@ [submodule "third_party/nvbandwidth"] path = third_party/nvbandwidth url = https://github.com/NVIDIA/nvbandwidth.git +[submodule "third_party/nvbench"] + path = third_party/nvbench + url = https://github.com/NVIDIA/nvbench.git diff --git a/dockerfile/cuda12.8.dockerfile b/dockerfile/cuda12.8.dockerfile index 595624003..56156a00a 100644 --- a/dockerfile/cuda12.8.dockerfile +++ b/dockerfile/cuda12.8.dockerfile @@ -61,6 +61,27 @@ RUN apt-get update && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* /tmp/* +# Install CMake 3.30.4 for nvbench compatibility +RUN apt-get update && \ + apt-get remove -y cmake cmake-data && \ + apt-get autoremove -y && \ + cd /tmp && \ + ARCH=$(uname -m) && \ + case ${ARCH} in \ + "aarch64") CMAKE_ARCH="aarch64" ;; \ + "x86_64") CMAKE_ARCH="x86_64" ;; \ + "arm64") CMAKE_ARCH="aarch64" ;; \ + *) CMAKE_ARCH="x86_64" ;; \ + esac && \ + echo "Detected architecture: ${ARCH}, using CMAKE_ARCH: ${CMAKE_ARCH}" && \ + wget -q https://github.com/Kitware/CMake/releases/download/v3.30.4/cmake-3.30.4-linux-${CMAKE_ARCH}.tar.gz && \ + tar -xzf cmake-3.30.4-linux-${CMAKE_ARCH}.tar.gz && \ + mv cmake-3.30.4-linux-${CMAKE_ARCH} /opt/cmake && \ + ln -sf /opt/cmake/bin/* /usr/local/bin/ && \ + rm -rf cmake-3.30.4-linux-${CMAKE_ARCH}* && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + ARG NUM_MAKE_JOBS= ARG TARGETPLATFORM ARG TARGETARCH @@ -161,7 +182,7 @@ ADD dockerfile/etc /opt/microsoft/ WORKDIR ${SB_HOME} ADD third_party third_party -RUN make -C third_party cuda_with_msccl +RUN make -C third_party cuda_with_msccl cuda_nvbench ADD . . RUN python3 -m pip install --upgrade setuptools==70.3.0 && \ diff --git a/dockerfile/cuda12.9.dockerfile b/dockerfile/cuda12.9.dockerfile index 29804506c..1b0352eb1 100644 --- a/dockerfile/cuda12.9.dockerfile +++ b/dockerfile/cuda12.9.dockerfile @@ -62,6 +62,28 @@ RUN apt-get update && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* /tmp/* +# Install CMake 3.30.4 for nvbench compatibility +RUN apt-get update && \ + apt-get remove -y cmake cmake-data && \ + apt-get autoremove -y && \ + cd /tmp && \ + ARCH=$(uname -m) && \ + case ${ARCH} in \ + "aarch64") CMAKE_ARCH="aarch64" ;; \ + "x86_64") CMAKE_ARCH="x86_64" ;; \ + "arm64") CMAKE_ARCH="aarch64" ;; \ + *) CMAKE_ARCH="x86_64" ;; \ + esac && \ + echo "Detected architecture: ${ARCH}, using CMAKE_ARCH: ${CMAKE_ARCH}" && \ + wget -q https://github.com/Kitware/CMake/releases/download/v3.30.4/cmake-3.30.4-linux-${CMAKE_ARCH}.tar.gz && \ + tar -xzf cmake-3.30.4-linux-${CMAKE_ARCH}.tar.gz && \ + mv cmake-3.30.4-linux-${CMAKE_ARCH} /opt/cmake && \ + ln -sf /opt/cmake/bin/* /usr/local/bin/ && \ + rm -rf cmake-3.30.4-linux-${CMAKE_ARCH}* && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + + ARG NUM_MAKE_JOBS= ARG TARGETPLATFORM ARG TARGETARCH @@ -162,7 +184,7 @@ ADD dockerfile/etc /opt/microsoft/ WORKDIR ${SB_HOME} ADD third_party third_party -RUN make -C third_party cuda_with_msccl +RUN make -C third_party cuda_with_msccl cuda_nvbench ADD . . RUN python3 -m pip install --upgrade setuptools==78.1.0 && \ diff --git a/dockerfile/cuda13.0.dockerfile b/dockerfile/cuda13.0.dockerfile index 25fc7f9ed..6a6f88f7f 100644 --- a/dockerfile/cuda13.0.dockerfile +++ b/dockerfile/cuda13.0.dockerfile @@ -62,6 +62,27 @@ RUN apt-get update && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* /tmp/* +# Install CMake 3.30.4 for nvbench compatibility +RUN apt-get update && \ + apt-get remove -y cmake cmake-data && \ + apt-get autoremove -y && \ + cd /tmp && \ + ARCH=$(uname -m) && \ + case ${ARCH} in \ + "aarch64") CMAKE_ARCH="aarch64" ;; \ + "x86_64") CMAKE_ARCH="x86_64" ;; \ + "arm64") CMAKE_ARCH="aarch64" ;; \ + *) CMAKE_ARCH="x86_64" ;; \ + esac && \ + echo "Detected architecture: ${ARCH}, using CMAKE_ARCH: ${CMAKE_ARCH}" && \ + wget -q https://github.com/Kitware/CMake/releases/download/v3.30.4/cmake-3.30.4-linux-${CMAKE_ARCH}.tar.gz && \ + tar -xzf cmake-3.30.4-linux-${CMAKE_ARCH}.tar.gz && \ + mv cmake-3.30.4-linux-${CMAKE_ARCH} /opt/cmake && \ + ln -sf /opt/cmake/bin/* /usr/local/bin/ && \ + rm -rf cmake-3.30.4-linux-${CMAKE_ARCH}* && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + ARG NUM_MAKE_JOBS= ARG TARGETPLATFORM ARG TARGETARCH @@ -151,7 +172,7 @@ ADD dockerfile/etc /opt/microsoft/ WORKDIR ${SB_HOME} ADD third_party third_party -RUN make -C third_party cuda +RUN make -C third_party cuda cuda_nvbench ADD . . RUN python3 -m pip install --upgrade setuptools==78.1.0 && \ diff --git a/dockerfile/rocm5.0.x.dockerfile b/dockerfile/rocm5.0.x.dockerfile index 9ab35244c..e3e89cf3e 100644 --- a/dockerfile/rocm5.0.x.dockerfile +++ b/dockerfile/rocm5.0.x.dockerfile @@ -98,7 +98,7 @@ RUN cd /tmp && \ # Install Intel MLC RUN cd /tmp && \ - wget -q https://downloadmirror.intel.com/793041/mlc_v3.11.tgz -O mlc.tgz && \ + wget -q https://downloadmirror.intel.com/866182/mlc_v3.12.tgz -O mlc.tgz && \ tar xzf mlc.tgz Linux/mlc && \ cp ./Linux/mlc /usr/local/bin/ && \ rm -rf ./Linux mlc.tgz diff --git a/docs/user-tutorial/benchmarks/micro-benchmarks.md b/docs/user-tutorial/benchmarks/micro-benchmarks.md index aa3aa965b..a5bc2fa5c 100644 --- a/docs/user-tutorial/benchmarks/micro-benchmarks.md +++ b/docs/user-tutorial/benchmarks/micro-benchmarks.md @@ -172,6 +172,51 @@ Supports the use of double unit types and the use of tensor cores. | gpu-burn/gpu_[0-9]_pass | yes/no | The result of the gpu-burn test for each GPU (1: yes, 0: no). | | gpu-burn/abort | yes/no | Whether or not GPU-burn test aborted before returning GPU results (1: yes, 0: no). | +### `nvbench-sleep-kernel` + +#### Introduction + +Measure GPU kernel execution time using NVBench's sleep kernel benchmark. This benchmark creates CUDA kernels that sleep for specified durations (in microseconds) and measures the actual execution time, providing insights into GPU scheduling overhead and timing accuracy. + +The benchmark supports multiple duration specification formats: +- Single value: `"50"` - Test single duration of 50μs +- List format: `"[25,50,75]"` - Test multiple specific durations +- Range format: `"[25:75]"` - Test all values from 25μs to 75μs +- Range with step: `"[0:50:10]"` - Test from 0μs to 50μs in steps of 10μs + +Performed by [NVBench](https://github.com/NVIDIA/nvbench) sleep kernel benchmark. + +#### Metrics + +| Name | Unit | Description | +|-----------------------------------------|-----------|-------------------------------------------------------| +| nvbench-sleep-kernel/duration_us_{X}_cpu_time | time (μs) | CPU-measured time for duration X microseconds. | +| nvbench-sleep-kernel/duration_us_{X}_gpu_time | time (μs) | GPU-measured time for duration X microseconds. | +| nvbench-sleep-kernel/duration_us_{X}_batch_gpu_time | time (μs) | GPU batch execution time for duration X microseconds. | + +Where `{X}` is the sleep duration in microseconds (e.g., 25, 50, 75). + +### `nvbench-kernel-launch` + +#### Introduction + +Measure GPU kernel launch overhead and execution time using NVBench's kernel launch benchmark. This benchmark evaluates the time required to launch kernels on the GPU and measures both CPU-side and GPU-side timing for kernel execution. + +The benchmark provides insights into: +- Kernel launch latency +- CPU/GPU synchronization overhead +- Batch execution performance + +Performed by [NVBench](https://github.com/NVIDIA/nvbench) kernel launch benchmark. + +#### Metrics + +| Name | Unit | Description | +|-------------------------------------|-----------|------------------------------------------------| +| nvbench-kernel-launch/cpu_time | time (μs) | CPU-measured kernel execution time. | +| nvbench-kernel-launch/gpu_time | time (μs) | GPU-measured kernel execution time. | +| nvbench-kernel-launch/batch_gpu_time | time (μs) | GPU batch execution time. | + ### `cpu-hpl` #### Introduction diff --git a/examples/benchmarks/nvbench_kernel_launch.py b/examples/benchmarks/nvbench_kernel_launch.py new file mode 100644 index 000000000..c0f74f55a --- /dev/null +++ b/examples/benchmarks/nvbench_kernel_launch.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Example of NVBench Kernel Launch benchmark.""" + +from superbench.benchmarks import BenchmarkRegistry, Platform +from superbench.common.utils import logger + +if __name__ == '__main__': + context = BenchmarkRegistry.create_benchmark_context( + 'nvbench-kernel-launch', + platform=Platform.CUDA, + parameters=( + '--timeout 30 ' + '--min-samples 10 ' + '--min-time 1.0 ' + '--max-noise 0.1 ' + '--stopping-criterion stdrel ' + '--throttle-threshold 80 ' + '--throttle-recovery-delay 1.0' + ) + ) + + benchmark = BenchmarkRegistry.launch_benchmark(context) + if benchmark: + logger.info( + 'benchmark: {}, return code: {}, result: {}'.format( + benchmark.name, benchmark.return_code, benchmark.result + ) + ) diff --git a/examples/benchmarks/nvbench_sleep_kernel.py b/examples/benchmarks/nvbench_sleep_kernel.py new file mode 100644 index 000000000..083bd0a7c --- /dev/null +++ b/examples/benchmarks/nvbench_sleep_kernel.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Example of NVBench Sleep Kernel benchmark.""" + +from superbench.benchmarks import BenchmarkRegistry, Platform +from superbench.common.utils import logger + + +def main(): + """Main method to run the nvbench sleep kernel benchmark.""" + context = BenchmarkRegistry.create_benchmark_context( + 'nvbench-sleep-kernel', platform=Platform.CUDA, parameters='--duration_us "[25,50,75]" --timeout 10' + ) + + benchmark = BenchmarkRegistry.launch_benchmark(context) + if benchmark: + logger.info( + 'benchmark: {}, return code: {}, result: {}'.format( + benchmark.name, benchmark.return_code, benchmark.result + ) + ) + else: + logger.error('benchmark: nvbench-sleep-kernel launch failed.') + + +if __name__ == '__main__': + main() diff --git a/superbench/benchmarks/micro_benchmarks/__init__.py b/superbench/benchmarks/micro_benchmarks/__init__.py index 978c2d385..47a786f6d 100644 --- a/superbench/benchmarks/micro_benchmarks/__init__.py +++ b/superbench/benchmarks/micro_benchmarks/__init__.py @@ -39,42 +39,16 @@ from superbench.benchmarks.micro_benchmarks.directx_mem_bw_performance import DirectXGPUMemBw from superbench.benchmarks.micro_benchmarks.directx_gemm_flops_performance import DirectXGPUCoreFlops from superbench.benchmarks.micro_benchmarks.nvbandwidth import NvBandwidthBenchmark +from superbench.benchmarks.micro_benchmarks.nvbench_kernel_launch import NvbenchKernelLaunch +from superbench.benchmarks.micro_benchmarks.nvbench_sleep_kernel import NvbenchSleepKernel __all__ = [ - 'BlasLtBaseBenchmark', - 'ComputationCommunicationOverlap', - 'CpuMemBwLatencyBenchmark', - 'CpuHplBenchmark', - 'CpuStreamBenchmark', - 'CublasBenchmark', - 'CublasLtBenchmark', - 'CudaGemmFlopsBenchmark', - 'CudaMemBwBenchmark', - 'CudaNcclBwBenchmark', - 'CudnnBenchmark', - 'DiskBenchmark', - 'DistInference', - 'HipBlasLtBenchmark', - 'GPCNetBenchmark', - 'GemmFlopsBenchmark', - 'GpuBurnBenchmark', - 'GpuCopyBwBenchmark', - 'GpuStreamBenchmark', - 'IBBenchmark', - 'IBLoopbackBenchmark', - 'KernelLaunch', - 'MemBwBenchmark', - 'MicroBenchmark', - 'MicroBenchmarkWithInvoke', - 'ORTInferenceBenchmark', - 'RocmGemmFlopsBenchmark', - 'RocmMemBwBenchmark', - 'ShardingMatmul', - 'TCPConnectivityBenchmark', - 'TensorRTInferenceBenchmark', - 'DirectXGPUEncodingLatency', - 'DirectXGPUCopyBw', - 'DirectXGPUMemBw', - 'DirectXGPUCoreFlops', - 'NvBandwidthBenchmark', + 'BlasLtBaseBenchmark', 'ComputationCommunicationOverlap', 'CpuMemBwLatencyBenchmark', 'CpuHplBenchmark', + 'CpuStreamBenchmark', 'CublasBenchmark', 'CublasLtBenchmark', 'CudaGemmFlopsBenchmark', 'CudaMemBwBenchmark', + 'CudaNcclBwBenchmark', 'CudnnBenchmark', 'DiskBenchmark', 'DistInference', 'HipBlasLtBenchmark', 'GPCNetBenchmark', + 'GemmFlopsBenchmark', 'GpuBurnBenchmark', 'GpuCopyBwBenchmark', 'GpuStreamBenchmark', 'IBBenchmark', + 'IBLoopbackBenchmark', 'KernelLaunch', 'MemBwBenchmark', 'MicroBenchmark', 'MicroBenchmarkWithInvoke', + 'ORTInferenceBenchmark', 'RocmGemmFlopsBenchmark', 'RocmMemBwBenchmark', 'ShardingMatmul', + 'TCPConnectivityBenchmark', 'TensorRTInferenceBenchmark', 'DirectXGPUEncodingLatency', 'DirectXGPUCopyBw', + 'DirectXGPUMemBw', 'DirectXGPUCoreFlops', 'NvBandwidthBenchmark', 'NvbenchKernelLaunch', 'NvbenchSleepKernel' ] diff --git a/superbench/benchmarks/micro_benchmarks/nvbench/CMakeLists.txt b/superbench/benchmarks/micro_benchmarks/nvbench/CMakeLists.txt new file mode 100644 index 000000000..8415b10b4 --- /dev/null +++ b/superbench/benchmarks/micro_benchmarks/nvbench/CMakeLists.txt @@ -0,0 +1,45 @@ +cmake_minimum_required(VERSION 3.18) +project(nvbench_benchmarks LANGUAGES CUDA) + +# Check if we have a recent enough CMake for nvbench (which requires 3.30.4) +if(CMAKE_VERSION VERSION_LESS "3.30.4") + message(STATUS "CMake version ${CMAKE_VERSION} is less than 3.30.4 (required by nvbench), skipping nvbench benchmarks") + return() +endif() + +find_package(CUDAToolkit QUIET) +if (CUDAToolkit_FOUND) + include(../cuda_common.cmake) + + # Try to find nvbench, but don't require it + find_package(nvbench CONFIG QUIET) + + if (nvbench_FOUND) + message(STATUS "Found nvbench, building nvbench benchmarks") + + # list all your CUDA benchmark source files here + set(NVBENCH_SOURCES + kernel_launch.cu + sleep_kernel.cu + # add more *.cu as needed + ) + + foreach(src ${NVBENCH_SOURCES}) + # strip ".cu" → NAME_WE + get_filename_component(basename ${src} NAME_WE) + set(target nvbench_${basename}) + + add_executable(${target} ${src}) + target_compile_features(${target} PUBLIC cuda_std_17) + target_link_libraries(${target} + PRIVATE nvbench::nvbench nvbench::main + ) + install(TARGETS ${target} RUNTIME DESTINATION bin) + endforeach() + else() + message(STATUS "nvbench not found, skipping nvbench benchmarks.") + message(STATUS "To build nvbench benchmarks, first build the submodule in third_party/nvbench") + endif() +else() + message(STATUS "CUDA not found, skipping nvbench benchmarks.") +endif() \ No newline at end of file diff --git a/superbench/benchmarks/micro_benchmarks/nvbench/kernel_launch.cu b/superbench/benchmarks/micro_benchmarks/nvbench/kernel_launch.cu new file mode 100644 index 000000000..08dc40294 --- /dev/null +++ b/superbench/benchmarks/micro_benchmarks/nvbench/kernel_launch.cu @@ -0,0 +1,9 @@ +#include + +__global__ void empty_kernel() {} + +void kernel_launch(nvbench::state &state) { + state.exec([](nvbench::launch &launch) { empty_kernel<<<1, 1, 0, launch.get_stream()>>>(); }); +} + +NVBENCH_BENCH(kernel_launch); \ No newline at end of file diff --git a/superbench/benchmarks/micro_benchmarks/nvbench/sleep_kernel.cu b/superbench/benchmarks/micro_benchmarks/nvbench/sleep_kernel.cu new file mode 100644 index 000000000..b4789377e --- /dev/null +++ b/superbench/benchmarks/micro_benchmarks/nvbench/sleep_kernel.cu @@ -0,0 +1,22 @@ +#include +#include +#include + +__global__ void sleep_kernel(nvbench::int64_t microseconds) { + const auto start = cuda::std::chrono::high_resolution_clock::now(); + const auto target_duration = cuda::std::chrono::microseconds(microseconds); + const auto finish = start + target_duration; + + while (cuda::std::chrono::high_resolution_clock::now() < finish) { + // busy wait + } +} + +void sleep_benchmark(nvbench::state &state) { + const auto duration_us = state.get_int64("Duration (us)"); + state.exec( + [&duration_us](nvbench::launch &launch) { sleep_kernel<<<1, 1, 0, launch.get_stream()>>>(duration_us); }); +} +NVBENCH_BENCH(sleep_benchmark) + .add_int64_axis("Duration (us)", nvbench::range(0, 100, 5)) + .set_timeout(1); // Limit to one second per measurement. \ No newline at end of file diff --git a/superbench/benchmarks/micro_benchmarks/nvbench_base.py b/superbench/benchmarks/micro_benchmarks/nvbench_base.py new file mode 100644 index 000000000..98e705b46 --- /dev/null +++ b/superbench/benchmarks/micro_benchmarks/nvbench_base.py @@ -0,0 +1,255 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Base class for NVBench benchmarks.""" + +import os +import re +from superbench.common.utils import logger +from superbench.benchmarks import ReturnCode +from superbench.benchmarks.micro_benchmarks.micro_base import MicroBenchmarkWithInvoke + + +def parse_time_to_us(raw: str) -> float: + """Helper: parse '123.45 us', '678.9 ns', '0.12 ms' → float µs.""" + raw = raw.strip() + if raw.endswith('%'): + return float(raw[:-1]) + # split "value unit" or "valueunit" + m = re.match(r'([\d.]+)\s*([mun]?s)?', raw) + if not m: + return float(raw) + val, unit = float(m.group(1)), (m.group(2) or 'us') + if unit == 'ns': + return val / 1e3 + if unit == 'ms': + return val * 1e3 + return val + + +class NvbenchBase(MicroBenchmarkWithInvoke): + """Base class for NVBench benchmarks with common functionality.""" + def __init__(self, name, parameters=''): + """Constructor. + + Args: + name (str): benchmark name. + parameters (str): benchmark parameters. + """ + super().__init__(name, parameters) + # Subclasses should set this + self._bin_name = None + + def add_parser_arguments(self): + """Add common NVBench arguments.""" + super().add_parser_arguments() + + # Device configuration + self._parser.add_argument( + '--devices', + type=str, + default=None, + help='Device list to run the benchmark, e.g., "0,1,2,3" or "all".', + ) + + # Benchmark Properties + self._parser.add_argument( + '--skip-time', + type=float, + default=-1.0, + help='Skip time in seconds.', + ) + self._parser.add_argument( + '--throttle-threshold', + type=float, + default=75.0, + help='Throttle threshold percentage.', + ) + self._parser.add_argument( + '--throttle-recovery-delay', + type=float, + default=0.05, + help='Throttle recovery delay in seconds.', + ) + self._parser.add_argument( + '--run-once', + action='store_true', + help='Run once flag.', + ) + self._parser.add_argument( + '--disable-blocking-kernel', + action='store_true', + help='Disable blocking kernel flag.', + ) + self._parser.add_argument( + '--profile', + action='store_true', + help='Enable profiling flag.', + ) + + # Stopping Criteria + self._parser.add_argument( + '--timeout', + type=int, + default=15, + help='Timeout in seconds.', + ) + self._parser.add_argument( + '--min-samples', + type=int, + default=10, + help='Minimum number of samples.', + ) + self._parser.add_argument( + '--stopping-criterion', + type=str, + default='stdrel', + choices=['stdrel', 'entropy'], + help='Stopping criterion.', + ) + # stdrel-specific + self._parser.add_argument( + '--min-time', + type=float, + default=0.5, + help='Minimum time for stdrel stopping criterion.', + ) + self._parser.add_argument( + '--max-noise', + type=float, + default=0.5, + help='Maximum noise for stdrel stopping criterion.', + ) + # entropy-specific + self._parser.add_argument( + '--max-angle', + type=float, + default=0.048, + help='Maximum angle for entropy stopping criterion.', + ) + self._parser.add_argument( + '--min-r2', + type=float, + default=0.36, + help='Minimum R-squared for entropy stopping criterion.', + ) + + def _add_device_args(self, parts): + """Add device configuration arguments to command parts.""" + if hasattr(self._args, 'devices') and self._args.devices is not None: + if self._args.devices == 'all': + parts.extend(['--devices', 'all']) + else: + parts.extend(['--devices', self._args.devices]) + + def _add_benchmark_property_args(self, parts): + """Add benchmark property arguments to command parts.""" + if hasattr(self._args, 'skip_time') and self._args.skip_time >= 0: + parts.extend(['--skip-time', str(self._args.skip_time)]) + if hasattr(self._args, 'throttle_threshold') and self._args.throttle_threshold > 0: + parts.extend(['--throttle-threshold', str(self._args.throttle_threshold)]) + if hasattr(self._args, 'throttle_recovery_delay') and self._args.throttle_recovery_delay > 0: + parts.extend(['--throttle-recovery-delay', str(self._args.throttle_recovery_delay)]) + if hasattr(self._args, 'run_once') and self._args.run_once: + parts.append('--run-once') + if hasattr(self._args, 'disable_blocking_kernel') and self._args.disable_blocking_kernel: + parts.append('--disable-blocking-kernel') + if hasattr(self._args, 'profile') and self._args.profile: + parts.append('--profile') + + def _add_stopping_criteria_args(self, parts): + """Add stopping criteria arguments to command parts.""" + if hasattr(self._args, 'timeout') and self._args.timeout is not None: + parts.extend(['--timeout', str(self._args.timeout)]) + if hasattr(self._args, 'min_samples') and self._args.min_samples is not None: + parts.extend(['--min-samples', str(self._args.min_samples)]) + if hasattr(self._args, 'stopping_criterion') and self._args.stopping_criterion: + parts.extend(['--stopping-criterion', self._args.stopping_criterion]) + if self._args.stopping_criterion == 'stdrel': + self._add_stdrel_args(parts) + elif self._args.stopping_criterion == 'entropy': + self._add_entropy_args(parts) + + def _add_stdrel_args(self, parts): + """Add stdrel-specific stopping criterion arguments.""" + if hasattr(self._args, 'min_time') and self._args.min_time is not None: + parts.extend(['--min-time', str(self._args.min_time)]) + if hasattr(self._args, 'max_noise') and self._args.max_noise is not None: + parts.extend(['--max-noise', str(self._args.max_noise)]) + + def _add_entropy_args(self, parts): + """Add entropy-specific stopping criterion arguments.""" + if hasattr(self._args, 'max_angle') and self._args.max_angle is not None: + parts.extend(['--max-angle', str(self._args.max_angle)]) + if hasattr(self._args, 'min_r2') and self._args.min_r2 is not None: + parts.extend(['--min-r2', str(self._args.min_r2)]) + + def _build_base_command(self): + """Build the base nvbench command with common arguments. + + Returns: + list: Command parts that can be extended by subclasses. + """ + if not self._bin_name: + raise ValueError('Subclass must set _bin_name') + + command = os.path.join(self._args.bin_dir, self._bin_name) + parts = [command] + + self._add_device_args(parts) + self._add_benchmark_property_args(parts) + self._add_stopping_criteria_args(parts) + + return parts + + def _preprocess(self): + """Default preprocess implementation. Can be overridden by subclasses. + + Returns: + True if _preprocess() succeed. + """ + if not super()._preprocess(): + return False + + # Build base command - subclasses can override this method to add specific arguments + parts = self._build_base_command() + + # Finalize command + self._commands = [' '.join(parts)] + return True + + def _parse_time_value(self, time_str): + """Parse time string to microseconds. + + Args: + time_str (str): Time string like '123.45 us', '678.9 ns', etc. + + Returns: + float: Time in microseconds. + """ + return parse_time_to_us(time_str) + + def _parse_percentage(self, percent_str): + """Parse percentage string to float. + + Args: + percent_str (str): Percentage string like '12.34%' + + Returns: + float: Percentage value as float. + """ + if isinstance(percent_str, str) and percent_str.endswith('%'): + return float(percent_str[:-1]) + return float(percent_str) + + def _handle_parsing_error(self, error_msg, raw_output): + """Handle parsing errors consistently. + + Args: + error_msg (str): Error message to log. + raw_output (str): Raw output that failed to parse. + """ + self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE) + logger.error( + f'Invalid result format - round:{self._curr_run_index}, bench:{self._name}, msg:{error_msg}\n{raw_output}' + ) diff --git a/superbench/benchmarks/micro_benchmarks/nvbench_kernel_launch.py b/superbench/benchmarks/micro_benchmarks/nvbench_kernel_launch.py new file mode 100644 index 000000000..5120f1b51 --- /dev/null +++ b/superbench/benchmarks/micro_benchmarks/nvbench_kernel_launch.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Module of the NVBench Kernel Launch benchmark.""" + +import re +from superbench.common.utils import logger +from superbench.benchmarks import BenchmarkRegistry, Platform +from superbench.benchmarks.micro_benchmarks.nvbench_base import NvbenchBase + + +class NvbenchKernelLaunch(NvbenchBase): + """The NVBench Kernel Launch benchmark class.""" + def __init__(self, name, parameters=''): + """Constructor. + + Args: + name (str): benchmark name. + parameters (str): benchmark parameters. + """ + super().__init__(name, parameters) + self._bin_name = 'nvbench_kernel_launch' + + def _process_raw_result(self, cmd_idx, raw_output): + """Function to parse raw results and save the summarized results. + + Args: + cmd_idx (int): the index of command corresponding with the raw_output. + raw_output (str): raw output string of the micro-benchmark. + + Return: + True if the raw output string is valid and result can be extracted. + """ + self._result.add_raw_data(f'raw_output_{cmd_idx}', raw_output, self._args.log_raw_data) + + try: + gpu_section = r'### \[(\d+)\] NVIDIA' + # Regex pattern to handle different time units and flexible spacing + row_pat = ( + r'\|\s*([0-9]+)x\s*\|\s*' # Samples + r'([\d.]+\s*[μmun]?s)\s*\|\s*' # CPU Time (μs, ns, ms, us, s) + r'([\d.]+%)\s*\|\s*' # CPU Noise percentage + r'([\d.]+\s*[μmun]?s)\s*\|\s*' # GPU Time + r'([\d.]+%)\s*\|\s*' # GPU Noise percentage + r'([0-9]+)x\s*\|\s*' # Batch Samples + r'([\d.]+\s*[μmun]?s)\s*\|' # Batch GPU Time + ) + current = None + parsed_any = False # Track if any valid rows are parsed + + for line in raw_output.splitlines(): + line = line.strip() + g = re.match(gpu_section, line) + if g: + current = f'gpu_{g.group(1)}' + continue + + r = re.match(row_pat, line) + if r and current: + samples, cpu_time, cpu_noise, gpu_time, gpu_noise, batch_samples, batch_gpu = r.groups() + # self._result.add_result('samples', int(samples.replace('x', ''))) + self._result.add_result('cpu_time', self._parse_time_value(cpu_time)) + # self._result.add_result('cpu_noise', self._parse_percentage(cpu_noise)) + self._result.add_result('gpu_time', self._parse_time_value(gpu_time)) + # self._result.add_result('gpu_noise', self._parse_percentage(gpu_noise)) + # self._result.add_result('batch_samples', int(batch_samples.replace('x', ''))) + self._result.add_result('batch_gpu_time', self._parse_time_value(batch_gpu)) + parsed_any = True + + if not parsed_any: + logger.error('No valid rows parsed from the raw output.') + raise RuntimeError('No valid rows parsed') + + except Exception as e: + self._handle_parsing_error(str(e), raw_output) + return False + + return True + + +BenchmarkRegistry.register_benchmark('nvbench-kernel-launch', NvbenchKernelLaunch, platform=Platform.CUDA) diff --git a/superbench/benchmarks/micro_benchmarks/nvbench_sleep_kernel.py b/superbench/benchmarks/micro_benchmarks/nvbench_sleep_kernel.py new file mode 100644 index 000000000..e7bcb4322 --- /dev/null +++ b/superbench/benchmarks/micro_benchmarks/nvbench_sleep_kernel.py @@ -0,0 +1,121 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Module of the NVBench Sleep Kernel benchmark.""" + +import re +from superbench.common.utils import logger +from superbench.benchmarks import BenchmarkRegistry, Platform, ReturnCode +from superbench.benchmarks.micro_benchmarks.nvbench_base import NvbenchBase + + +class NvbenchSleepKernel(NvbenchBase): + """The NVBench Sleep Kernel benchmark class.""" + def __init__(self, name, parameters=''): + """Constructor. + + Args: + name (str): benchmark name. + parameters (str): benchmark parameters. + """ + super().__init__(name, parameters) + + self._bin_name = 'nvbench_sleep_kernel' + + def add_parser_arguments(self): + """Add sleep-kernel specific arguments.""" + super().add_parser_arguments() + + # Sleep-kernel specific argument + self._parser.add_argument( + '--duration_us', + type=str, + default='[0,25,50,75,100]', + help='Duration axis values in microseconds. Supports multiple formats: ' + '"50" (single value), "[25,50,75]" (list), "[25:75]" (range), "[0:50:10]" (range with step).', + ) + + def _preprocess(self): + """Preprocess/preparation operations before the benchmarking. + + Return: + True if _preprocess() succeed. + """ + if not super()._preprocess(): + return False + + # Build base command with common nvbench arguments + parts = self._build_base_command() + + # Add sleep-kernel specific arguments + parts.extend(['--axis', f'"Duration (us)={self._args.duration_us.strip()}"']) + + # Finalize command + self._commands = [' '.join(parts)] + return True + + def _process_raw_result(self, cmd_idx, raw_output): + """Function to parse raw results and save the summarized results. + + self._result.add_raw_data() and self._result.add_result() need to be called to save the results. + + Args: + cmd_idx (int): the index of command corresponding with the raw_output. + raw_output (str): raw output string of the micro-benchmark. + + Return: + True if the raw output string is valid and result can be extracted. + """ + logger.debug(f'Processing raw result for command index {cmd_idx}.') + logger.debug(f'Raw output:\n{raw_output}') + + self._result.add_raw_data(f'raw_output_{cmd_idx}', raw_output, self._args.log_raw_data) + try: + gpu_section = r'### \[(\d+)\] NVIDIA' + # Regex pattern to handle different time units and flexible spacing + row_pat = ( + r'\|\s*([0-9]+)\s*\|\s*' # Duration (us) + r'([0-9]+)x\s*\|\s*' # Samples + r'([\d.]+\s*[μmun]?s)\s*\|\s*' # CPU Time (μs, ns, ms, us, s) + r'([\d.]+%)\s*\|\s*' # CPU Noise percentage + r'([\d.]+\s*[μmun]?s)\s*\|\s*' # GPU Time + r'([\d.]+%)\s*\|\s*' # GPU Noise percentage + r'([0-9]+)x\s*\|\s*' # Batch Samples + r'([\d.]+\s*[μmun]?s)\s*\|' # Batch GPU Time + ) + current = None + parsed_any = False + for line in raw_output.splitlines(): + line = line.strip() + logger.debug(f'Processing line: {line}') + g = re.match(gpu_section, line) + if g: + current = f'gpu_{g.group(1)}' + logger.debug(f'Found GPU section: {current}') + continue + r = re.match(row_pat, line) + if r and current: + logger.debug(f'Matched row: {r.groups()}') + duration_us, samples, cpu_time, cpu_noise, gpu_time, gpu_noise, batch_samples, batch_gpu = r.groups( + ) + # self._result.add_result(f'duration_us_{duration_us}_samples', int(samples)) + self._result.add_result(f'duration_us_{duration_us}_cpu_time', self._parse_time_value(cpu_time)) + # self._result.add_result(f'duration_us_{duration_us}_cpu_noise', self._parse_percentage(cpu_noise)) + self._result.add_result(f'duration_us_{duration_us}_gpu_time', self._parse_time_value(gpu_time)) + # self._result.add_result(f'duration_us_{duration_us}_gpu_noise', self._parse_percentage(gpu_noise)) + # self._result.add_result(f'duration_us_{duration_us}_batch_samples', + # int(batch_samples.replace('x', ''))) + self._result.add_result( + f'duration_us_{duration_us}_batch_gpu_time', self._parse_time_value(batch_gpu) + ) + parsed_any = True + if not parsed_any: + raise RuntimeError('No valid rows parsed') + except Exception as e: + logger.error(f'Error processing raw result: {e}') + self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE) + return False + return True + + +BenchmarkRegistry.register_benchmark('nvbench-sleep-kernel', NvbenchSleepKernel, platform=Platform.CUDA) diff --git a/tests/benchmarks/micro_benchmarks/test_nvbench_kernel_launch.py b/tests/benchmarks/micro_benchmarks/test_nvbench_kernel_launch.py new file mode 100644 index 000000000..02908e5eb --- /dev/null +++ b/tests/benchmarks/micro_benchmarks/test_nvbench_kernel_launch.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for nvbench kernel launch benchmark.""" + +import unittest + +from tests.helper import decorator +from tests.helper.testcase import BenchmarkTestCase +from superbench.benchmarks import BenchmarkRegistry, ReturnCode, Platform + + +class TestNvbenchKernelLaunchBenchmark(BenchmarkTestCase, unittest.TestCase): + """Test class for NVBench Kernel Launch benchmark.""" + @classmethod + def setUpClass(cls): + """Hook method for setting up class fixture before running tests in the class.""" + super().setUpClass() + cls.createMockEnvs(cls) + cls.createMockFiles(cls, ['bin/nvbench_kernel_launch']) + + def test_nvbench_kernel_launch_preprocess(self): + """Test NVBench Kernel Launch benchmark preprocess.""" + benchmark_name = 'nvbench-kernel-launch' + (benchmark_class, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, Platform.CUDA) + assert (benchmark_class) + + # Test preprocess with default parameters + benchmark = benchmark_class(benchmark_name, parameters='') + assert benchmark._preprocess() + assert benchmark.return_code == ReturnCode.SUCCESS + + # Test preprocess with specified parameters + parameters = ( + '--devices 0 ' + '--timeout 20 ' + '--min-samples 300 ' + '--stopping-criterion stdrel ' + '--min-time 2.0 ' + '--max-noise 0.5 ' + '--throttle-threshold 80.0 ' + '--throttle-recovery-delay 1.0' + ) + benchmark = benchmark_class(benchmark_name, parameters=parameters) + assert benchmark._preprocess() + assert benchmark.return_code == ReturnCode.SUCCESS + + # Check command + assert (1 == len(benchmark._commands)) + assert ('--devices 0' in benchmark._commands[0]) + assert ('--timeout 20' in benchmark._commands[0]) + assert ('--min-samples 300' in benchmark._commands[0]) + assert ('--stopping-criterion stdrel' in benchmark._commands[0]) + assert ('--min-time 2.0' in benchmark._commands[0]) + assert ('--max-noise 0.5' in benchmark._commands[0]) + assert ('--throttle-threshold 80.0' in benchmark._commands[0]) + assert ('--throttle-recovery-delay 1.0' in benchmark._commands[0]) + + @decorator.load_data('tests/data/nvbench_kernel_launch.log') + def test_nvbench_kernel_launch_result_parsing_real_output(self, results): + """Test NVBench Kernel Launch benchmark result parsing.""" + benchmark_name = 'nvbench-kernel-launch' + (benchmark_class, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, Platform.CUDA) + assert (benchmark_class) + + benchmark = benchmark_class(benchmark_name, parameters='') + + # Preprocess and validate command + assert benchmark._preprocess() + + # Parse the provided raw output + assert benchmark._process_raw_result(0, results) + assert benchmark.return_code == ReturnCode.SUCCESS + + # Validate parsed results + # assert benchmark.result['samples'][0] == 120000 + assert benchmark.result['cpu_time'][0] == 24.222 + # assert benchmark.result['cpu_noise'][0] == 30.44 + assert benchmark.result['gpu_time'][0] == 7.808 + # assert benchmark.result['gpu_noise'][0] == 14.42 + # assert benchmark.result['batch_samples'][0] == 300000 + assert benchmark.result['batch_gpu_time'][0] == 6.024 + + def test_nvbench_kernel_launch_process_raw_result_invalid_output(self): + """Test NVBench Kernel Launch benchmark result parsing with invalid output.""" + benchmark_name = 'nvbench-kernel-launch' + (benchmark_class, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, Platform.CUDA) + assert (benchmark_class) + + benchmark = benchmark_class(benchmark_name, parameters='') + + # Preprocess and validate command + assert benchmark._preprocess() + + # Mock raw output with invalid format + raw_output = 'Invalid output format' + + # Parse the provided raw output + assert not benchmark._process_raw_result(0, raw_output) + assert benchmark.return_code == ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/benchmarks/micro_benchmarks/test_nvbench_sleep_kernel.py b/tests/benchmarks/micro_benchmarks/test_nvbench_sleep_kernel.py new file mode 100644 index 000000000..4606768ff --- /dev/null +++ b/tests/benchmarks/micro_benchmarks/test_nvbench_sleep_kernel.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for nvbench sleep kernel benchmark.""" + +import unittest + +from tests.helper import decorator +from tests.helper.testcase import BenchmarkTestCase +from superbench.benchmarks import BenchmarkRegistry, ReturnCode, Platform + + +class TestNvbenchSleepKernelBenchmark(BenchmarkTestCase, unittest.TestCase): + """Test class for NVBench Sleep Kernel benchmark.""" + @classmethod + def setUpClass(cls): + """Hook method for setting up class fixture before running tests in the class.""" + super().setUpClass() + cls.createMockEnvs(cls) + cls.createMockFiles(cls, ['bin/nvbench_sleep_kernel']) + + def test_nvbench_sleep_kernel_preprocess(self): + """Test NVBench Sleep Kernel benchmark preprocess.""" + benchmark_name = 'nvbench-sleep-kernel' + (benchmark_class, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, Platform.CUDA) + assert (benchmark_class) + + # Test preprocess with default parameters + benchmark = benchmark_class(benchmark_name, parameters='') + assert benchmark._preprocess() + assert benchmark.return_code == ReturnCode.SUCCESS + + # Test preprocess with specified parameters + parameters = ( + '--devices 0 ' + '--duration_us "[10,25,50,75]" ' + '--timeout 20 ' + '--min-samples 300 ' + '--stopping-criterion stdrel ' + '--min-time 2.0 ' + '--max-noise 0.5 ' + '--throttle-threshold 80.0 ' + '--throttle-recovery-delay 1.0' + ) + benchmark = benchmark_class(benchmark_name, parameters=parameters) + assert benchmark._preprocess() + assert benchmark.return_code == ReturnCode.SUCCESS + + # Check command + assert (1 == len(benchmark._commands)) + assert ('--devices 0' in benchmark._commands[0]) + assert ('--axis "Duration (us)=[10,25,50,75]"' in benchmark._commands[0]) + assert ('--timeout 20' in benchmark._commands[0]) + assert ('--min-samples 300' in benchmark._commands[0]) + assert ('--stopping-criterion stdrel' in benchmark._commands[0]) + assert ('--min-time 2.0' in benchmark._commands[0]) + assert ('--max-noise 0.5' in benchmark._commands[0]) + assert ('--throttle-threshold 80.0' in benchmark._commands[0]) + assert ('--throttle-recovery-delay 1.0' in benchmark._commands[0]) + + @decorator.load_data('tests/data/nvbench_sleep_kernel.log') + def test_nvbench_sleep_kernel_result_parsing_real_output(self, results): + """Test NVBench Sleep Kernel benchmark result parsing.""" + benchmark_name = 'nvbench-sleep-kernel' + (benchmark_class, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, Platform.CUDA) + assert (benchmark_class) + + benchmark = benchmark_class(benchmark_name, parameters='') + + # Preprocess and validate command + assert benchmark._preprocess() + + # Parse the provided raw output + assert benchmark._process_raw_result(0, results) + assert benchmark.return_code == ReturnCode.SUCCESS + + # Validate parsed results + # assert benchmark.result['duration_us_25_samples'][0] == 10175 + assert benchmark.result['duration_us_25_cpu_time'][0] == 42.123 + # assert benchmark.result['duration_us_25_cpu_noise'][0] == 69.78 + assert benchmark.result['duration_us_25_gpu_time'][0] == 25.321 + # assert benchmark.result['duration_us_25_gpu_noise'][0] == 0.93 + # assert benchmark.result['duration_us_25_batch_samples'][0] == 17448 + assert benchmark.result['duration_us_25_batch_gpu_time'][0] == 23.456 + + # assert benchmark.result['duration_us_50_samples'][0] == 8187 + # assert benchmark.result['duration_us_75_samples'][0] == 6279 + + def test_nvbench_sleep_kernel_preprocess_duration_formats(self): + """Test NVBench Sleep Kernel preprocess with different duration formats.""" + benchmark_name = 'nvbench-sleep-kernel' + (benchmark_class, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, Platform.CUDA) + assert (benchmark_class) + + # Test single value + benchmark = benchmark_class(benchmark_name, parameters='--duration_us "50"') + assert benchmark._preprocess() + assert '--axis "Duration (us)=50"' in benchmark._commands[0] + + # Test list format + benchmark = benchmark_class(benchmark_name, parameters='--duration_us "[25,50,75]"') + assert benchmark._preprocess() + assert '--axis "Duration (us)=[25,50,75]"' in benchmark._commands[0] + + # Test range format + benchmark = benchmark_class(benchmark_name, parameters='--duration_us "[25:75]"') + assert benchmark._preprocess() + assert '--axis "Duration (us)=[25:75]"' in benchmark._commands[0] + + # Test range with step format + benchmark = benchmark_class(benchmark_name, parameters='--duration_us "[0:50:10]"') + assert benchmark._preprocess() + assert '--axis "Duration (us)=[0:50:10]"' in benchmark._commands[0] + + # Test default format + benchmark = benchmark_class(benchmark_name, parameters='') + assert benchmark._preprocess() + assert '--axis "Duration (us)=[0,25,50,75,100]"' in benchmark._commands[0] + + def test_nvbench_sleep_kernel_process_raw_result_invalid_output(self): + """Test NVBench Sleep Kernel benchmark result parsing with invalid output.""" + benchmark_name = 'nvbench-sleep-kernel' + (benchmark_class, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(benchmark_name, Platform.CUDA) + assert (benchmark_class) + + benchmark = benchmark_class(benchmark_name, parameters='') + + # Preprocess and validate command + assert benchmark._preprocess() + + # Mock raw output with invalid format + raw_output = 'Invalid output format' + + # Parse the provided raw output + assert not benchmark._process_raw_result(0, raw_output) + assert benchmark.return_code == ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/data/nvbench_kernel_launch.log b/tests/data/nvbench_kernel_launch.log new file mode 100644 index 000000000..95ccbc065 --- /dev/null +++ b/tests/data/nvbench_kernel_launch.log @@ -0,0 +1,34 @@ +# Devices + +## [0] `NVIDIA GPU` +* SM Version: 900 (PTX Version: 900) +* Number of SMs: 100 +* SM Default Clock Rate: 800 MHz +* Global Memory: 100000 MiB Free / 100000 MiB Total +* Global Memory Bus Peak: 3000 GB/sec (3000-bit DDR @4000MHz) +* Max Shared Memory: 100 KiB/SM, 20 KiB/Block +* L2 Cache Size: 1000 KiB +* Maximum Active Blocks: 10/SM +* Maximum Active Threads: 500/SM, 300/Block +* Available Registers: 500/SM, 500/Block +* ECC Enabled: Yes + +# Log + +``` +Run: [1/1] launch_bench [Device=0] +Warn: Current measurement timed out (20.00s) while over noise threshold (14.42% > 0.50%) +Warn: Current measurement timed out (20.00s) before accumulating min_time (0.94s < 2.00s) +Pass: Cold: 0.005878ms GPU, 0.022181ms CPU, 0.94s total GPU, 20.00s total wall, 120000x +Pass: Batch: 0.004024ms GPU, 2.00s total GPU, 2.00s total wall, 300000x +``` + +# Benchmark Results + +## launch_bench + +### [0] NVIDIA GPU + +| Samples | CPU Time | Noise | GPU Time | Noise | Samples | Batch GPU | +|---------|-----------|--------|----------|--------|---------|-----------| +| 120000x | 24.222 us | 30.44% | 7.808 us | 14.42% | 300000x | 6.024 us | \ No newline at end of file diff --git a/tests/data/nvbench_sleep_kernel.log b/tests/data/nvbench_sleep_kernel.log new file mode 100644 index 000000000..b2c4037da --- /dev/null +++ b/tests/data/nvbench_sleep_kernel.log @@ -0,0 +1,43 @@ +# Devices + +## [0] `NVIDIA GPU` +* SM Version: 900 (PTX Version: 900) +* Number of SMs: 100 +* SM Default Clock Rate: 800 MHz +* Global Memory: 100000 MiB Free / 100000 MiB Total +* Global Memory Bus Peak: 3000 GB/sec (3000-bit DDR @4000MHz) +* Max Shared Memory: 100 KiB/SM, 20 KiB/Block +* L2 Cache Size: 1000 KiB +* Maximum Active Blocks: 10/SM +* Maximum Active Threads: 500/SM, 300/Block +* Available Registers: 500/SM, 500/Block +* ECC Enabled: Yes + +# Log + +``` +Run: [1/3] sleep_benchmark [Device=0 Duration (us)=25] +Warn: Current measurement timed out (1.00s) while over noise threshold (0.93% > 0.50%) +Warn: Current measurement timed out (1.00s) before accumulating min_time (0.31s < 0.50s) +Pass: Cold: 0.030374ms GPU, 0.047379ms CPU, 0.31s total GPU, 1.00s total wall, 10175x +Pass: Batch: 0.028658ms GPU, 0.50s total GPU, 0.50s total wall, 17448x +Run: [2/3] sleep_benchmark [Device=0 Duration (us)=50] +Warn: Current measurement timed out (1.00s) before accumulating min_time (0.45s < 0.50s) +Pass: Cold: 0.055036ms GPU, 0.072054ms CPU, 0.45s total GPU, 1.00s total wall, 8187x +Pass: Batch: 0.053246ms GPU, 0.50s total GPU, 0.50s total wall, 9403x +Run: [3/3] sleep_benchmark [Device=0 Duration (us)=75] +Pass: Cold: 0.079643ms GPU, 0.096788ms CPU, 0.50s total GPU, 0.92s total wall, 6279x +Pass: Batch: 0.077862ms GPU, 0.51s total GPU, 0.51s total wall, 6547x +``` + +# Benchmark Results + +## sleep_benchmark + +### [0] NVIDIA GPU + +| Duration (us) | Samples | CPU Time | Noise | GPU Time | Noise | Samples | Batch GPU | +|---------------|---------|-----------|--------|-----------|-------|---------|-----------| +| 25 | 10175x | 42.123 us | 69.78% | 25.321 us | 0.93% | 17448x | 23.456 us | +| 50 | 8187x | 68.456 us | 2.34% | 50.654 us | 0.45% | 9403x | 49.321 us | +| 75 | 6279x | 90.789 us | 1.85% | 75.987 us | 0.33% | 6547x | 77.862 us | \ No newline at end of file diff --git a/third_party/Makefile b/third_party/Makefile index 2a09f5990..b25fca042 100755 --- a/third_party/Makefile +++ b/third_party/Makefile @@ -191,7 +191,7 @@ endif cpu_hpl: sb_micro_path ifneq (,$(wildcard hpl-tests/Makefile)) cd ./hpl-tests && \ - wget https://netlib.org/benchmark/hpl/hpl-2.3.tar.gz && \ + wget https://netlib.org/benchmark/hpl/hpl-2.3.tar.gz && \ tar xzf hpl-2.3.tar.gz && \ cp Make.Linux_zen3 hpl-2.3 && \ cp Make.Linux_zen4 hpl-2.3 && \ @@ -207,7 +207,7 @@ endif cpu_stream: sb_micro_path ifneq (,$(wildcard stream-tests/Makefile)) cd ./stream-tests && \ - wget https://www.cs.virginia.edu/stream/FTP/Code/stream.c && \ + wget https://www.cs.virginia.edu/stream/FTP/Code/stream.c && \ make all cp -v ./stream-tests/stream* $(SB_MICRO_PATH)/bin/ endif @@ -242,10 +242,10 @@ rocm_megatron_lm: fi cp Megatron/rocm/Megatron-LM/examples/deepseek_v2/pretrain_deepseek.py Megatron/rocm/Megatron-LM/ git clone https://github.com/caaatch22/grouped_gemm.git &&\ - cd grouped_gemm &&\ - git checkout 8a9b438 &&\ - git submodule update --init --recursive &&\ - pip install . + cd grouped_gemm &&\ + git checkout 8a9b438 &&\ + git submodule update --init --recursive &&\ + pip install . # Instal apex of ROCm due to dependency of Megatron apex_rocm: @@ -263,7 +263,7 @@ apex_rocm: elif [ "$$(expr $(TORCH_MAJOR_VERSION) == 2)" -eq 1 ] && [ "$$(expr $(TORCH_MINOR_VERSION) == 0)" -eq 1 ]; then \ git checkout release/1.0.0 ; \ elif [ "$$(expr $(TORCH_MAJOR_VERSION) == 1)" -eq 1 ]; then \ - git checkout release/1.0.0 ; \ + git checkout release/1.0.0 ; \ fi pip install -v --disable-pip-version-check --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./apex @@ -275,11 +275,11 @@ ifeq ($(shell echo $(CUDA_VER)">=12.9" | bc -l), 1) git clone --single-branch --branch main https://github.com/Azure/msccl.git \ && git -C msccl checkout 87048bd && git -C msccl submodule update --recursive --init else ifeq ($(shell echo $(CUDA_VER)">=12.8" | bc -l), 1) - # Get commit 87048bd from msscl to support updated nccl and sm_100 + # Get commit 87048bd from msscl to support updated nccl and sm_100 $(eval ARCHS := 75 80 86 89 90 100) if [ -d msccl ]; then rm -rf msccl; fi; \ git clone --single-branch --branch main https://github.com/Azure/msccl.git \ - && git -C msccl checkout 87048bd && git -C msccl submodule update --recursive --init + && git -C msccl checkout 87048bd && git -C msccl submodule update --recursive --init else ifeq ($(shell echo $(CUDA_VER)">=11.8" | bc -l), 1) $(eval ARCHS := 70 75 80 86 89 90) else @@ -311,3 +311,24 @@ endif nvbandwidth: sb_micro_path cd ./nvbandwidth && git apply ../nvbandwidth.patch && cp ../nvbandwidth_testcases_patched.h ./testcases_patched.h && cmake . && make && cd .. cp -v ./nvbandwidth/nvbandwidth $(SB_MICRO_PATH)/bin + +# Build nvbench +cuda_nvbench: sb_micro_path +ifeq ($(shell echo $(CUDA_VER)">=12.9" | bc -l), 1) + $(eval ARCHS := "100;103") +else ifeq ($(shell echo $(CUDA_VER)">=12.8" | bc -l), 1) + $(eval ARCHS := "90;100") +else ifeq ($(shell echo $(CUDA_VER)">=11.8" | bc -l), 1) + $(eval ARCHS := "70;75;80;86;89;90") +else + $(eval ARCHS := "70;75;80;86") +endif + cd ./nvbench && mkdir -p build && cd build && \ + cmake \ + -DCMAKE_INSTALL_PREFIX=$(SB_MICRO_PATH) \ + -DCMAKE_CUDA_ARCHITECTURES=$(ARCHS) \ + -DNVBench_ENABLE_CUPTI=ON \ + -DCMAKE_BUILD_TYPE=Release \ + .. && \ + make -j $(NUM_MAKE_JOBS) && \ + make install diff --git a/third_party/nvbench b/third_party/nvbench new file mode 160000 index 000000000..7feda2cf3 --- /dev/null +++ b/third_party/nvbench @@ -0,0 +1 @@ +Subproject commit 7feda2cf3ade88b3e73a0e0414ba543a4fbfbc43