Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Include essential project files
include README.md
include LICENSE
include pyproject.toml

# Include all Python source files
recursive-include src/madengine *.py

# Include all script files
recursive-include src/madengine/scripts *

# Include database schema files
include src/madengine/db/*.sql

# Include documentation files
recursive-include src/madengine *.md

# Include any configuration or data files
recursive-include src/madengine *.yml
recursive-include src/madengine *.yaml
recursive-include src/madengine *.json
recursive-include src/madengine *.toml
recursive-include src/madengine *.cfg
recursive-include src/madengine *.ini

# Include shell scripts and executables
recursive-include src/madengine *.sh
recursive-include src/madengine *.bash

# Include any template or configuration files
recursive-include src/madengine *.template
recursive-include src/madengine *.conf

# Exclude compiled Python files
global-exclude *.pyc
global-exclude *.pyo
global-exclude __pycache__

# Exclude version control
global-exclude .git*
global-exclude .svn*

# Exclude IDE and editor files
global-exclude .vscode*
global-exclude .idea*
global-exclude *.swp
global-exclude *.swo
global-exclude *~

# Exclude build and distribution artifacts
global-exclude build
global-exclude dist
global-exclude *.egg-info

# Exclude test artifacts
global-exclude .pytest_cache
global-exclude .coverage
global-exclude htmlcov

# Exclude temporary and log files
global-exclude *.log
global-exclude *.tmp
global-exclude temp*
37 changes: 33 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,14 @@ dynamic = ["version"]
authors = [
{ name="Advanced Micro Devices", email="mad.support@amd.com" },
]
description = "MAD Engine is a set of interfaces to run various AI models from public MAD."
maintainers = [
{ name="Advanced Micro Devices", email="mad.support@amd.com" },
]
description = "AI Models automation and dashboarding CLI tool for running LLMs and Deep Learning models"
readme = "README.md"
license = {text = "MIT"}
requires-python = ">=3.8"
keywords = ["AI", "machine-learning", "deep-learning", "LLM", "automation", "AMD", "ROCm", "GPU", "performance", "benchmarking"]
dependencies = [
"pandas",
"GitPython",
Expand All @@ -21,15 +26,28 @@ dependencies = [
"mysql-connector-python",
"pymysql",
"tqdm",
"pytest",
"typing-extensions",
"pymongo",
"toml",
"numpy",
"pynvml",
]
classifiers = [
"Programming Language :: Python :: 3",
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Operating System :: POSIX :: Linux",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Testing",
"Topic :: System :: Benchmark",
"Topic :: System :: Hardware",
]

[project.scripts]
Expand All @@ -49,11 +67,22 @@ dev = [
"pytest-asyncio",
]

[tool.hatch.build.targets.sdist]
include = [
"/src",
"/README.md",
"/LICENSE",
"/pyproject.toml",
]

[tool.hatch.build.targets.wheel]

[tool.hatch.build.targets.wheel.force-include]
"src/madengine/scripts" = "madengine/scripts"

[tool.hatch.build.targets.sdist.force-include]
"src/madengine/scripts" = "src/madengine/scripts"

[tool.hatch.version]
source = "versioningit"

Expand Down
3 changes: 2 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[pytest]
testpaths = tests
pythonpath = src
pythonpath = src
markers = packaging: packaging-related tests exercising the wheel build/install flow
92 changes: 61 additions & 31 deletions src/madengine/db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,67 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import mapper, clear_mappers

# MAD Engine modules
from logger import setup_logger
from base_class import BASE, BaseMixin
from utils import get_env_vars
# MAD Engine modules (dual import: package first, then standalone fallback for scp use)
try: # Package import context
from madengine.db.logger import setup_logger # type: ignore
from madengine.db.base_class import BASE, BaseMixin # type: ignore
from madengine.db.utils import get_env_vars # type: ignore
except ImportError: # Standalone (scp) execution context
from logger import setup_logger # type: ignore
from base_class import BASE, BaseMixin # type: ignore
from utils import get_env_vars # type: ignore


# Create the logger
LOGGER = setup_logger()
# Get the environment variables
ENV_VARS = get_env_vars()

# Check if the environment variables are set
if ENV_VARS["user_name"] is None or ENV_VARS["user_password"] is None:
raise ValueError("User name or password not set")
# Global engine variable - will be lazily initialized
ENGINE = None

if ENV_VARS["db_hostname"] is None or ENV_VARS["db_port"] is None:
raise ValueError("DB hostname or port not set")

if ENV_VARS["db_name"] is None:
raise ValueError("DB name not set")
def get_engine():
"""Get database engine, creating it lazily when first needed.

Returns:
sqlalchemy.engine.Engine: Database engine

Raises:
ValueError: If required environment variables are not set
"""
global ENGINE

if ENGINE is None:
# Check if the environment variables are set
if not ENV_VARS["user_name"] or not ENV_VARS["user_password"]:
raise ValueError("User name or password not set")

if not ENV_VARS["db_hostname"] or not ENV_VARS["db_port"]:
raise ValueError("DB hostname or port not set")

if not ENV_VARS["db_name"]:
raise ValueError("DB name not set")

# Create the engine
ENGINE = create_engine(
"mysql+pymysql://{user_name}:{user_password}@{hostname}:{port}/{db_name}".format(
user_name=ENV_VARS["user_name"],
user_password=ENV_VARS["user_password"],
hostname=ENV_VARS["db_hostname"],
port=ENV_VARS["db_port"],
db_name=ENV_VARS["db_name"],
)
)
LOGGER.info("Database engine created for %s@%s:%s/%s",
ENV_VARS["user_name"], ENV_VARS["db_hostname"],
ENV_VARS["db_port"], ENV_VARS["db_name"])

return ENGINE

# Create the engine
ENGINE = create_engine(
"mysql+pymysql://{user_name}:{user_password}@{hostname}:{port}/{db_name}".format(
user_name=ENV_VARS["user_name"],
user_password=ENV_VARS["user_password"],
hostname=ENV_VARS["db_hostname"],
port=ENV_VARS["db_port"],
db_name=ENV_VARS["db_name"],
)
)
# Check for eager initialization
if os.getenv("MADENGINE_DB_EAGER") == "1":
LOGGER.info("MADENGINE_DB_EAGER=1 detected, creating engine immediately")
ENGINE = get_engine()

# Define the path to the SQL file
SQL_FILE_PATH = os.path.join(os.path.dirname(__file__), 'db_table_def.sql')
Expand Down Expand Up @@ -99,20 +129,20 @@ def connect_db() -> None:
user_name = ENV_VARS["user_name"]

try:
ENGINE.execute("Use {}".format(db_name))
get_engine().execute("Use {}".format(db_name))
return
except OperationalError: # as err:
LOGGER.warning(
"Database %s does not exist, attempting to create database", db_name
)

try:
ENGINE.execute("Create database if not exists {}".format(db_name))
get_engine().execute("Create database if not exists {}".format(db_name))
except OperationalError as err:
LOGGER.error("Database creation failed %s for username: %s", err, user_name)

ENGINE.execute("Use {}".format(db_name))
ENGINE.execute("SET GLOBAL max_allowed_packet=4294967296")
get_engine().execute("Use {}".format(db_name))
get_engine().execute("SET GLOBAL max_allowed_packet=4294967296")


def clear_db() -> None:
Expand All @@ -126,7 +156,7 @@ def clear_db() -> None:
db_name = ENV_VARS["db_name"]

try:
ENGINE.execute("DROP DATABASE IF EXISTS {}".format(db_name))
get_engine().execute("DROP DATABASE IF EXISTS {}".format(db_name))
return
except OperationalError: # as err:
LOGGER.warning("Database %s could not be dropped", db_name)
Expand All @@ -143,13 +173,13 @@ def show_db() -> None:
db_name = ENV_VARS["db_name"]

try:
result = ENGINE.execute(
result = get_engine().execute(
"SELECT * FROM {} \
WHERE {}.created_date= \
(SELECT MAX(created_date) FROM {}) ;".format(DB_TABLE.__tablename__)
)
for row in result:
print(row)
LOGGER.info("Latest entry: %s", row)
return
except OperationalError: # as err:
LOGGER.warning("Database %s could not be shown", db_name)
Expand Down Expand Up @@ -195,7 +225,7 @@ def trim_column(col_name: str) -> None:
Raises:
OperationalError: An error occurred while trimming the column.
"""
ENGINE.execute(
get_engine().execute(
"UPDATE {} \
SET \
{} = TRIM({});".format(
Expand All @@ -218,7 +248,7 @@ def get_column_names() -> list:
"""
db_name = ENV_VARS["db_name"]

result = ENGINE.execute(
result = get_engine().execute(
"SELECT `COLUMN_NAME` \
FROM `INFORMATION_SCHEMA`.`COLUMNS` \
WHERE `TABLE_SCHEMA`='{}' \
Expand Down
15 changes: 9 additions & 6 deletions src/madengine/db/database_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
# built-in modules
import typing

# MAD Engine modules
from database import ENGINE
# MAD Engine modules (dual import)
try:
from madengine.db.database import get_engine, LOGGER # type: ignore
except ImportError:
from database import get_engine, LOGGER # type: ignore


def get_all_gpu_archs() -> typing.List[str]:
Expand All @@ -18,7 +21,7 @@ def get_all_gpu_archs() -> typing.List[str]:
Returns:
typing.List[str]: A list of all GPU architectures in the database.
"""
matching_entries = ENGINE.execute(
matching_entries = get_engine().execute(
"SELECT DISTINCT(gpu_architecture) FROM dlm_table"
)

Expand All @@ -43,7 +46,7 @@ def get_matching_db_entries(
Returns:
typing.List[typing.Dict[str, typing.Any]]: The matching entries.
"""
print(
LOGGER.info(
"Looking for entries with {}, {} and {}".format(
recent_entry["model"],
recent_entry["gpu_architecture"],
Expand All @@ -52,7 +55,7 @@ def get_matching_db_entries(
)

# find matching entries to current entry
matching_entries = ENGINE.execute(
matching_entries = get_engine().execute(
"SELECT * FROM dlm_table \
WHERE model='{}' \
AND gpu_architecture='{}' \
Expand All @@ -74,7 +77,7 @@ def get_matching_db_entries(
if should_add:
filtered_matching_entries.append(m)

print(
LOGGER.info(
"Found {} similar entries in database filtered down to {} entries".format(
len(matching_entries),
len(filtered_matching_entries)
Expand Down
Loading