diff --git a/dataset-generation/README b/dataset-generation/README new file mode 100644 index 0000000..39989cd --- /dev/null +++ b/dataset-generation/README @@ -0,0 +1,2 @@ +For dataset generation scripts, a minimum Python version of 3.10+ is required +Install requirements listed in requirements.txt \ No newline at end of file diff --git a/dataset-generation/create_linux_db.py b/dataset-generation/create_linux_db.py new file mode 100644 index 0000000..53369c5 --- /dev/null +++ b/dataset-generation/create_linux_db.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import argparse +import requests +import gzip + +from enum import Enum +from dataclasses import dataclass +from pathlib import Path, PurePosixPath +from datetime import datetime +from io import BytesIO, TextIOWrapper +from http import HTTPStatus +from debian.deb822 import Deb822, Sources +from tqdm.auto import tqdm + +from typing import ClassVar +from collections.abc import Mapping, Sequence + +from dapper_python.databases_v2.database import Metadata +from dapper_python.databases_v2.linux_db import LinuxDatabase, PackageFile, PackageSource + + +class LinuxDistro: + class Distro(Enum): + """Currently supported distros""" + UBUNTU = 'ubuntu' + DEBIAN = 'debian' + + @dataclass + class _DistroInfo: + archive_url: str + contents_path: str + sources_path: str + + @property + def contents_url(self) -> str: + return self.archive_url + self.contents_path + + @property + def sources_url(self) -> str: + return self.archive_url + self.sources_path + + def __init__(self, distro: Distro, release: str) -> None: + try: + candidate_infos = self.DISTRO_MAP[distro] + except KeyError as e: + raise KeyError(f"Invalid distro: {distro}") from e + + # Check if the release actually exists, if we get a non-404 then it means it likely exists + if not isinstance(candidate_infos, Sequence): + candidate_infos = (candidate_infos,) + for candidate_info in candidate_infos: + with requests.head(candidate_info.contents_url.format(release=release)) as response: + if response.status_code != HTTPStatus.NOT_FOUND: + self._dist_info = candidate_info + break + else: # Exits loop without break + raise ValueError(f"Release {release} does not exist for distro \"{distro.value}\"") + + self._distro = distro + self._release = release + + def get_contents(self, **kwargs) -> TextIOWrapper: + """Downloads the contents file for the distro + release""" + data, _ = self.get_file(self._dist_info.contents_path.format(release=self._release), **kwargs) + with gzip.open(data) as gz_file: + return TextIOWrapper(BytesIO(gz_file.read()), encoding="utf-8") + + def get_sources(self, **kwargs) -> TextIOWrapper: + """Downloads the sources file for the distro + release""" + data, _ = self.get_file(self._dist_info.sources_path.format(release=self._release), **kwargs) + with gzip.open(data) as gz_file: + return TextIOWrapper(BytesIO(gz_file.read()), encoding="utf-8") + + def get_file(self, path: str, *, progress_params: Mapping | bool = False) -> tuple[BytesIO, str | None]: + """Utility function for downloading files from the distro archive""" + url = self._dist_info.archive_url + path + with requests.get(url, stream=True) as response: + response.raise_for_status() + if 'content-length' in response.headers: + file_size = int(response.headers['content-length']) + else: + file_size = None + + _progress_params = { + "total": file_size, + "desc": "Downloading file", + "unit": 'B', + "unit_divisor": 1024, + "unit_scale": True, + "position": None, + "leave": None, + } + if 
isinstance(progress_params, Mapping): + _progress_params.update(progress_params) + elif isinstance(progress_params, bool): + _progress_params["disable"] = not progress_params + + content = BytesIO() + with tqdm(**_progress_params) as progress_bar: + for chunk in response.iter_content(chunk_size=8 * 1024): + content.write(chunk) + progress_bar.update(len(chunk)) + + content.seek(0) + return content, response.headers.get('Content-Type', None) + + DISTRO_MAP: ClassVar[dict[Distro, _DistroInfo]] = { + Distro.UBUNTU: _DistroInfo( + archive_url=r"https://archive.ubuntu.com/ubuntu/", + contents_path=r"dists/{release}/Contents-amd64.gz", + sources_path=r"dists/{release}/main/source/Sources.gz", + ), + Distro.DEBIAN: ( + # Debian has two different sites for currently supported distros and older archived distros + # Need to check both + _DistroInfo( + archive_url=r"https://deb.debian.org/debian/", + contents_path=r"dists/{release}/main/Contents-amd64.gz", + sources_path=r"dists/{release}/main/source/Sources.gz", + ), + _DistroInfo( + archive_url=r"https://archive.debian.org/debian/", + contents_path=r"dists/{release}/main/Contents-amd64.gz", + sources_path=r"dists/{release}/main/source/Sources.gz", + ), + ), + } + + +def main(): + parser = argparse.ArgumentParser( + description="Create Linux DB by parsing the Linux Contents file", + ) + parser.add_argument( + "-o", "--output", + required=False, + type=Path, default=Path('LinuxPackageDB.db'), + help='Path of output (database) file to create. Defaults to "LinuxPackageDB.db" in the current working directory', + ) + parser.add_argument( + '-v', '--version', + type=int, required=True, + help='Version marker for the database to keep track of changes', + ) + + parser.add_argument( + "distro", + type=LinuxDistro.Distro, choices=LinuxDistro.Distro, + help="Name of the distro to scrape", + ) + parser.add_argument( + "release", + type=str, + help="Name of the release to scrape", + ) + args = parser.parse_args() + + # Currently not set up to be able to handle resuming a previously started database + # It's not a high priority as the process only takes few minutes to process. 
Can delete the old DB and recreate + if args.output.exists(): + raise FileExistsError(f"File {args.output} already exists") + + linux_distro = LinuxDistro(args.distro, args.release) + + linux_db = LinuxDatabase.create_database(args.output, exist_ok=False) + with linux_db.session() as session: + # Parse contents file + with session.begin(): + contents_data = linux_distro.get_contents(progress_params=True) + entry_count = sum(1 for _ in contents_data) + contents_data.seek(0) + + # Operate using generator expressions for more efficient memory usage + progress_iter = tqdm( + contents_data, + total=entry_count, + desc='Processing Contents', colour='green', + unit='Entry', + ) + parsed_lines = ( + tuple(x.strip() for x in entry.rsplit(maxsplit=1)) + for entry in progress_iter + ) + package_files = ( + PackageFile( + file_path=PurePosixPath(file_path), + package_name=full_package_name.rsplit('/', maxsplit=1)[-1], + full_package_name=full_package_name, + ) + for file_path, full_package_name in parsed_lines + ) + session.bulk_insert(package_files, batch_size=50_000) + + # Parse sources file + with session.begin(): + sources_data = linux_distro.get_sources(progress_params=True) + entry_count = sum(1 for _ in Deb822.iter_paragraphs(sources_data)) + sources_data.seek(0) + + # Operate using generator expressions for more efficient memory usage + progress_iter = tqdm( + Deb822.iter_paragraphs(sources_data), + total=entry_count, + desc='Processing Sources', colour='cyan', + unit='Entry', + ) + package_sources = ( + PackageSource( + package_name=entry.get("Package"), + bin_package=bin_package.strip(), + ) + for entry in progress_iter + for bin_package in entry.get('Binary').split(',') + ) + session.bulk_insert(package_sources, batch_size=50_000) + + # Set version + with session.begin(): + session.add(Metadata( + version=args.version, + format="Linux", + timestamp=int(datetime.now().timestamp()), + )) + + +if __name__ == "__main__": + main() diff --git a/dataset-generation/create_maven_db.py b/dataset-generation/create_maven_db.py new file mode 100644 index 0000000..8ca0c41 --- /dev/null +++ b/dataset-generation/create_maven_db.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import argparse +import requests +import math +import time + +from datetime import datetime +from pathlib import Path +from http import HTTPStatus +from tqdm.auto import tqdm + +from dapper_python.databases_v2.database import Metadata +from dapper_python.databases_v2.maven_db import MavenDatabase, Package, PackageFile + +MAVEN_API_URL = "https://search.maven.org/solrsearch/select" + + +def main(): + parser = argparse.ArgumentParser( + description="Create java DB from Maven packages", + ) + parser.add_argument( + "-o", "--output", + required=False, + type=Path, default=Path("MavenPackageDB.db"), + help="Path of output (database) file to create. 
Defaults to \"MavenPackageDB.db\" in the current working directory",
    )
    parser.add_argument(
        "-v", "--version",
        type=int, required=True,
        help="Version marker for the database to keep track of changes",
    )
    args = parser.parse_args()

    # Currently not set up to be able to handle resuming a previously started database
    # Due to the way the Maven API returns data, it needs to be done in one session
    if args.output.exists():
        raise FileExistsError(f"File {args.output} already exists")

    query_params = {
        "q": "*:*",  # Query all packages
        "rows": 0,  # Number of results per page
        "start": 0,  # Offset for pagination
        "wt": "json",  # JSON output
    }
    with requests.get(MAVEN_API_URL, params=query_params) as response:
        response.raise_for_status()
        init_data = response.json()
        num_entries = init_data["response"]["numFound"]
        if not num_entries:
            print("No packages found")
            return

    maven_db = MavenDatabase.create_database(args.output, exist_ok=False)
    with maven_db.session() as session:
        with session.begin():
            # Can request a maximum of 200 entries
            CHUNK_SIZE = 200

            progress_bar = tqdm(
                total=num_entries,
                desc="Processing packages", colour="green",
                unit="Package",
                position=None, leave=None,
                disable=not num_entries,
            )
            for page in range(math.ceil(num_entries / CHUNK_SIZE)):
                query_params = {
                    "q": "*:*",
                    "rows": CHUNK_SIZE,
                    "start": page * CHUNK_SIZE,  # "start" is a result offset, not a page index
                    "wt": "json",
                }
                with requests.get(MAVEN_API_URL, params=query_params) as response:
                    response.raise_for_status()

                    data = response.json()
                    package_entries = data["response"]["docs"]

                    packages = []
                    for entry in package_entries:
                        group_id, _, package_name = entry["id"].partition(":")
                        package = Package(
                            package_name=package_name,
                            group_id=group_id,
                            timestamp=entry["timestamp"],
                            files=[
                                PackageFile(file_name=entry["a"] + suffix)
                                for suffix in entry["ec"]
                            ],
                        )
                        packages.append(package)
                    session.bulk_insert(packages)
                    progress_bar.update(len(package_entries))

                # Try to rate-limit the requests since it's causing problems
                time.sleep(1)

        # Set version
        with session.begin():
            session.add(Metadata(
                version=args.version,
                format="Maven",
                timestamp=int(datetime.now().timestamp()),
            ))


if __name__ == "__main__":
    main()
diff --git a/dataset-generation/create_mingw_db.py b/dataset-generation/create_mingw_db.py
new file mode 100644
index 0000000..0db20a4
--- /dev/null
+++ b/dataset-generation/create_mingw_db.py
@@ -0,0 +1,417 @@
from __future__ import annotations

import re
import argparse
import json
import tarfile
import warnings
import logging
import tempfile
import more_itertools
import urllib3
import requests
import zstandard as zstd  # Newer python has this built in, but only in 3.14+
import magic
import angr
import pydemumble

from io import BytesIO
from pathlib import Path, PurePosixPath
from datetime import datetime
from dataclasses import dataclass
from sqlmodel import select, delete
from contextlib import contextmanager, suppress, ExitStack
from tarfile import TarFile
from bs4 import BeautifulSoup, Tag
from tqdm.auto import tqdm
from methodtools import lru_cache

try:
    from enum import StrEnum
except ImportError:
    from backports.strenum import StrEnum

from typing import Any
from typing_extensions import Self

from dapper_python.databases_v2.database import Metadata
from dapper_python.databases_v2.mingw_db import MinGWDatabase
from dapper_python.databases_v2.mingw_db import Package, PackageFile, SourceFile
from
dapper_python.databases_v2.mingw_db import FunctionSymbol +from dapper_python.dataset_generation.parsing.cpp import CPPTreeParser +from dapper_python.dataset_generation.utils.archive import SafeTarFile, SafeZipFile + +# Note: Using verify=False for requests is not ideal, but otherwise breaks due to corporate network certificates +PACKAGE_INDEX_URL = "https://packages.msys2.org/packages" + + +class Arch(StrEnum): + """The architecture options available on MySYS2""" + UCRT_64 = "ucrt64" + CLANG_64 = "clang64" + CLANG_ARM_64 = "clangarm64" + MYSYS = "mysys" + MINGW_64 = "mingw64" + MINGW_32 = "mingw32" + + +@dataclass +class MySysPackage: + package_name: str + package_version: str + package_url: str + description: str | None = None + + @property + def source_url(self) -> str: + source_url, _ = self._fetch_artifact_urls() + return source_url + + def get_source(self) -> BytesIO: + """Gets the source tarball for the package + + Returned bytes should be opened as a tarfile + """ + with suppress_warnings(): + with requests.get(self.source_url, verify=False, stream=True) as response: + with zstd.ZstdDecompressor().stream_reader(response.content) as reader: + decompressed_tarball = BytesIO(reader.read()) + return decompressed_tarball + + @property + def contents_url(self) -> str: + _, binary_url = self._fetch_artifact_urls() + return binary_url + + def get_contents(self) -> BytesIO: + """Gets the package contents tarball for the pacakge + + Returned bytes should be opened as a tarfile + """ + with suppress_warnings(): + with requests.get(self.contents_url, verify=False, stream=True) as response: + with zstd.ZstdDecompressor().stream_reader(response.content) as reader: + decompressed_tarball = BytesIO(reader.read()) + return decompressed_tarball + + @lru_cache(maxsize=1) + def _fetch_artifact_urls(self) -> tuple[str, str]: + with suppress_warnings(), requests.get(self.package_url, verify=False) as response: + soup = BeautifulSoup(response.text, 'html.parser') + + def find_entry(label: str) -> Tag | None: + """Finds and returns the tag following the
with matching label (case-insensitive).""" + dt = soup.find("dt", string=re.compile(rf"^{label}\s*:?$", re.I)) + if not dt: + return None + dd = dt.find_next_sibling("dd") + if not dd: + return None + return dd.find("a", href=True) + + source_url = find_entry("Source-Only Tarball").get("href") + binary_url = find_entry("File").get("href") + + return source_url, binary_url + + +class PackageAnalyzer: + def __init__(self, package: MySysPackage) -> None: + self._mysys_package = package + + self._exit_stack = ExitStack() + self._temp_dir: Path | None = None + self._source_dir: Path | None = None + self._package_dir: Path | None = None + + def __enter__(self) -> Self: + self._temp_dir = Path(self._exit_stack.enter_context(tempfile.TemporaryDirectory())) + self._source_dir = self._temp_dir.joinpath("source") + self._source_dir.mkdir(exist_ok=True) + self._package_dir = self._temp_dir.joinpath("package") + self._package_dir.mkdir(exist_ok=True) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._temp_dir = None + self._source_dir = None + self._package_dir = None + return self._exit_stack.__exit__(exc_type, exc_val, exc_tb) + + def analyze_package(self) -> Package | None: + """Analyzes the package and returns the parsed data""" + if self._temp_dir is None: + raise RuntimeError("Must be used within context manager") + + mingw_package = Package( + package_name=self._mysys_package.package_name, + package_version=self._mysys_package.package_version, + ) + + try: + analyzed_package_sources = self._analyze_package_source() + mingw_package.source_files = analyzed_package_sources + except (zstd.ZstdError, tarfile.ReadError): + return None + + with suppress(zstd.ZstdError, tarfile.ReadError): + analyzed_package_files, symbols = self._analyze_package_contents() + mingw_package.package_files = analyzed_package_files + + if symbols: + for source_file in mingw_package.source_files: + for function_symbol in source_file.functions: + # Demangled symbols to not contain return type and many compiled types do not match + # Such as std::string -> std::__cxx11::basic_string, std::allocator> + function_symbol.in_binary = function_symbol.qualified_symbol_name in symbols + + return mingw_package + + def _analyze_package_source(self) -> list[SourceFile]: + if self._source_dir is None: + raise RuntimeError("Must be used within context manager") + + with suppress_warnings(), SafeTarFile.open(fileobj=self._mysys_package.get_source(), mode="r:*") as outer_tar: + file_list = outer_tar.getmembers() + + sub_files = [x for x in file_list if ".tar" in x.name] + for sub_tar in sub_files: + data = BytesIO(outer_tar.extractfile(sub_tar).read()) + with SafeTarFile.open(fileobj=data, mode="r:*") as inner_tar: + inner_tar.safe_extractall(self._source_dir) + + sub_files = [x for x in file_list if x.name.endswith(".zip")] + for sub_zip in sub_files: + data = BytesIO(outer_tar.extractfile(sub_zip).read()) + with SafeZipFile(data) as inner_zip: + inner_zip.safe_extractall(self._source_dir) + + dirs = [x for x in self._source_dir.iterdir() if x.is_dir()] + source_root = dirs[0] if len(dirs) == 1 else self._source_dir + + # Process all C/C++ files + files = [ + x for x in source_root.rglob("*") + if x.suffix.lower() in (".c", ".cpp", ".h", ".hpp", ".tpp") + and x.is_file() + ] + + source_files = [] + file_progress_iter = tqdm( + files, + desc="Parsing Files", colour="cyan", + unit="File", + position=None, leave=None, + disable=not files, + ) + for file in file_progress_iter: + source_file = SourceFile( + 
file_path=PurePosixPath(file.relative_to(source_root)), + ) + + tree = CPPTreeParser.from_source(file.read_bytes()) + source_file.functions = [ + FunctionSymbol( + return_type=x.return_type, + symbol_name=x.symbol_name, + qualified_symbol_name=x.qualified_symbol_name, + params=x.params, + full_signature=x.full_signature, + source_text=x.source_text, + ) + for x in tree.parse_functions() + ] + source_files.append(source_file) + + return source_files + + def _analyze_package_contents(self) -> tuple[list[PackageFile], set[str]]: + if self._package_dir is None: + raise RuntimeError("Must be used within context manager") + + with suppress_warnings(), TarFile.open(fileobj=self._mysys_package.get_contents(), mode="r:*") as tar_file: + tar_file.extractall(self._package_dir) + + dirs = [x for x in self._package_dir.iterdir() if x.is_dir()] + package_root = next( + (x for x in dirs if any(y == x.name for y in Arch)), + self._package_dir, + ) + files = [x for x in package_root.rglob("*") if x.is_file()] + + package_files = [] + symbols = set() + file_progress_iter = tqdm( + files, + desc="Parsing Files", colour="cyan", + unit="File", + position=None, leave=None, + disable=not files, + ) + for file in file_progress_iter: + try: + mime_type = magic.from_file(str(file.absolute()), mime=True) + magic_string = magic.from_file(str(file.absolute())) + except magic.MagicException: + mime_type = None + magic_string = None + + # Scan the file for any symbols + with disable_logging(), suppress(Exception): + angr_proj = angr.Project(file, auto_load_libs=False) + demangled_symbols = ( + pydemumble.demangle(x.name).strip() + for x in angr_proj.loader.main_object.symbols + ) + demangled_functions = (x for x in demangled_symbols if not x.startswith(self._NON_FUNCTION_PREFIXES)) + # Grab just the qualified name (without parameters) to compare as many compiled types are different from their source-code counterpart + symbols.update(x.split("(")[0] for x in demangled_functions) + + package_file = PackageFile( + file_path=PurePosixPath(file.relative_to(package_root)), + mime_type=mime_type, + magic_string=magic_string, + ) + package_files.append(package_file) + + return package_files, symbols + + _NON_FUNCTION_PREFIXES = ( + "sub_", # Special case since we don't want anonymous functions/code sections angr finds without a name + "vtable for", + "typeinfo for", + "typeinfo name for", + "covariant return thunk for", + "covariant return thunk to", + "construction vtable for", + "virtual thunk to", + "non-virtual thunk to", + "guard variable for", + "transaction clone for", + "VTT for", + "TLS wrapper function for", + ) + + +@contextmanager +def suppress_warnings(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=urllib3.exceptions.InsecureRequestWarning) + warnings.simplefilter("ignore", category=RuntimeWarning) + yield + + +@contextmanager +def disable_logging(highest_level=logging.CRITICAL): + previous_level = logging.root.manager.disable + try: + logging.disable(highest_level) + yield + finally: + logging.disable(previous_level) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "-o", "--output", + type=Path, default=Path("MinGWDB.db"), + help="Name of the output database file", + ) + parser.add_argument( + "-v", "--version", + type=int, required=True, + help="Version of the database", + ) + args = parser.parse_args() + + params = {"repo": Arch.MINGW_64} + with suppress_warnings(), requests.get(PACKAGE_INDEX_URL, params=params, verify=False) as response: + 
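        # Illustrative alternative (assumption, not used by this script): rather than
        # verify=False, a trusted CA bundle can be supplied, e.g.
        #   requests.get(url, verify="/path/to/corporate-ca.pem")
        # or via the REQUESTS_CA_BUNDLE environment variable; the bundle path here is hypothetical.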
        response.raise_for_status()
        soup = BeautifulSoup(response.text, 'html.parser')

    # Find table in the page.
    # Unfortunate lack of clearly distinct identifiers to search for (e.g id="package_list")
    # So we need to follow the chain of tags and hope it doesn't change
    table = soup.find("table", class_="table-hover")
    tbody = table.find("tbody")

    package_list: dict[str, MySysPackage] = {}
    for row in tbody.find_all("tr"):
        cols = row.find_all("td")
        if len(cols) != 3:
            continue

        link_tag = cols[0].find("a")
        package_link = link_tag['href']
        package_name = link_tag.text.strip()
        version = cols[1].text.strip()
        description = cols[2].text.strip()

        package_list[package_name] = MySysPackage(
            package_name=package_name,
            package_version=version,
            description=description,
            package_url=package_link,
        )

    mingw_db = MinGWDatabase.create_database(args.output, exist_ok=True)
    with mingw_db.session() as session:
        # Remove any outdated packages
        with session.begin():
            to_update = more_itertools.peekable((
                package
                for package in session.exec(select(Package))
                if package.package_name in package_list
                and package_list.get(package.package_name).package_version != package.package_version
            ))
            progress_iter = tqdm(
                to_update,
                desc="Removing outdated packages",
                unit="Package",
                position=None, leave=False,
                disable=not to_update,
            )
            session.bulk_delete(progress_iter)

        # noinspection PyTypeChecker, Pydantic
        saved_packages: set[str] = set(session.exec(select(Package.package_name)))
        to_update = sorted(list(set(package_list.keys()) - saved_packages))
        to_update = [
            package_list[package_name]
            for package_name in to_update
        ]

        # Get new packages and add to the database
        progress_iter = tqdm(
            to_update,
            desc="Scraping Packages", colour="blue",
            unit="Package",
            position=None, leave=None,
            disable=not to_update,
        )
        for package in progress_iter:
            with PackageAnalyzer(package) as analyzer:
                mingw_package = analyzer.analyze_package()
                if not mingw_package:
                    continue

                with session.begin():
                    session.add(mingw_package)

                # Due to somewhat high memory usage, free up memory before the next loop starts
                del mingw_package
                mingw_package = None

        # Reset the metadata if it already exists and set new version
        with session.begin():
            session.exec(delete(Metadata))
            session.add(Metadata(
                version=args.version,
                format="MinGW",
                timestamp=int(datetime.now().timestamp()),
            ))


if __name__ == "__main__":
    main()
diff --git a/dataset-generation/create_pypi_db.py b/dataset-generation/create_pypi_db.py
new file mode 100644
index 0000000..681509d
--- /dev/null
+++ b/dataset-generation/create_pypi_db.py
@@ -0,0 +1,300 @@
from __future__ import annotations

import argparse
import requests
import zipfile, zlib
import functools
import methodtools
import more_itertools
import magic

from pathlib import Path, PurePosixPath
from dataclasses import dataclass
from sqlmodel import select, delete
from datetime import datetime
from io import BytesIO
from zipfile import ZipFile
from contextlib import suppress, ExitStack
from natsort import natsorted
from tqdm.auto import tqdm

from collections.abc import Generator
from typing import ClassVar, Any
from typing_extensions import Self

from dapper_python.utils import yet_more_itertools
from dapper_python.databases_v2.database import Metadata
from dapper_python.databases_v2.python_db import PyPIDatabase, Package, PackageImport, PackageFile
from
dapper_python.dataset_generation.utils.scraping import get_with_retry +from dapper_python.dataset_generation.utils.archive import SafeZipFile +from dapper_python.dataset_generation.utils.futures import BoundedThreadPoolExecutor + +PYPI_INDEX_URL = 'https://pypi.python.org/simple/' + + +@dataclass +class PyPIPackage: + package_name: str + + @methodtools.lru_cache(maxsize=1) + def fetch_metadata(self) -> dict[str, Any]: + """Gets the information contained on the package's PyPI page in json format + + :return: JSON-formatted data retrieved from the endpoint + """ + url = self._API_PACKAGE_URL.format(package_name=self.package_name) + with get_with_retry(url) as response: + return response.json() + + def fetch_wheels(self) -> Generator[Wheel, None, None]: + """Gets the wheel files for the package""" + package_info = self.fetch_metadata() + + # Only keep ones that have wheels and have not been yanked + releases = dict(natsorted(package_info['releases'].items(), reverse=True)) + releases = { + version: data + for version, data in releases.items() + if any(( + x['packagetype'] == 'bdist_wheel' + and not x['yanked'] + for x in data + )) + } + if not releases: + return None + + # Grab all wheels (for all architectures) from the latest version that has not been yanked and has some wheels + version, release_data = next(iter(releases.items())) + for entry in release_data: + if not entry['packagetype'] == 'bdist_wheel': + continue + + with get_with_retry(entry['url'], stream=True) as response: + data = BytesIO(response.content) + with suppress(zipfile.BadZipFile): + yield Wheel(SafeZipFile(data)) + return None + + _API_PACKAGE_URL: ClassVar[str] = "https://pypi.org/pypi/{package_name}/json" + + +@dataclass +class Wheel: + archive: ZipFile + + def __post_init__(self) -> None: + self._exit_stack = ExitStack() + + def __enter__(self) -> Self: + self._exit_stack.enter_context(self.archive) + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + return self._exit_stack.__exit__(exc_type, exc_val, exc_tb) + + def get_imports(self) -> set[str]: + """Tries to get a list of names importable from the package""" + return self._get_top_level_imports() | self._infer_imports() + + def _get_top_level_imports(self) -> set[str]: + """Tries to get names importable from the package using the top-level.txt file""" + package_files = [PurePosixPath(x) for x in self.archive.namelist()] + + # Sometimes contains a top_level.txt file which details the top-level imports for the package + # If this is available, then use it as it's likely to be the most reliable information + top_level_txt = next((x for x in package_files if x.name == "top_level.txt"), None) + if not top_level_txt: + return set() + + content = self.archive.read(str(top_level_txt)).decode("utf-8") + imports = {line.strip() for line in content.splitlines() if line} + return {x for x in imports if x} + + def _infer_imports(self) -> set[str]: + """Tries to infer names importable from the package based on directory structure and contents + + Looks for .py files and directories containing __init__.py + """ + package_files = [PurePosixPath(x) for x in self.archive.namelist()] + + top_level_paths = { + entry.parents[-2] if len(entry.parents) >= 2 else entry + for entry in package_files + } + + # Check for any top-level python files, as these should also be importable + importable_files = { + file.stem + for file in top_level_paths + if file.suffix == ".py" and not file.name.startswith("_") + } + + # Check for any top-level paths that contain an 
__init__.py + importable_dirs = { + directory.name + for directory in top_level_paths + if any(( + file.name == "__init__.py" and file.parent == directory + for file in package_files + )) + } + + # TODO: Any other/better methods for determining importable names? + # This seems to produce a fair amount of correct values, but also a fair number of duplicates across packages + + importable = importable_files | importable_dirs + return {x for x in importable if x} + + def get_files(self) -> list[PackageFile]: + """Gets a list of files in the archive along with their mime types and magic string + + The "magic string" is the output of running libmagic on the file, hence the name "magic" string + Not that it is derived through unspecified means + """ + files: list[PackageFile] = [] + for file in self.archive.namelist(): + # Needed to change comprehension to loop+add in order to support exception handling + with suppress(zipfile.BadZipFile, zlib.error): + raw_data = self.archive.read(file) + + try: + mime_type = magic.from_buffer(raw_data, mime=True) + magic_string = magic.from_buffer(raw_data) + except magic.MagicException: + mime_type = None + magic_string = None + files.append(PackageFile( + file_path=PurePosixPath(file), + mime_type=mime_type, + magic_string=magic_string, + )) + + return files + + +def parse_package(name: str) -> Package | None: + """Creates a Package object for the package of the specified name + + Downloads the package's wheel files, parses the imports and records the file contents + Parses the result into a Package object which can be inserted into the database + + Needs to be a standalone function (callable) to be used with concurrent.futures + """ + try: + pypi_package = PyPIPackage(name) + package_info = pypi_package.fetch_metadata() + + package = Package( + package_name=name, + last_serial=package_info["last_serial"], + ) + + wheel_files = more_itertools.peekable(pypi_package.fetch_wheels()) + if not wheel_files: + return None + + imports: set[str] = set() + files: dict[PurePosixPath, PackageFile] = {} # Dict used for deduplication + for wheel in wheel_files: + with wheel: + imports.update(wheel.get_imports()) + for pkg_file in wheel.get_files(): + # Uses setdefault to only save the first occurrence + # If a file with the given path already exists, it won't be overwritten + files.setdefault(pkg_file.file_path, pkg_file) + package.imports = [PackageImport(import_as=x) for x in imports] + package.files = list(files.values()) + + return package + + # If we can't access the data, skip for now and we'll try again later + except (requests.exceptions.ConnectionError, requests.exceptions.RequestException): + return None + + +def main(): + parser = argparse.ArgumentParser( + description="Create Python imports DB from PyPI packages", + ) + parser.add_argument( + "-o", "--output", + required=False, + type=Path, default=Path("PyPIPackageDB.db"), + help="Path of output (database) file to create. 
Defaults to \"PyPIPackageDB.db\" in the current working directory", + ) + parser.add_argument( + "-v", "--version", + type=int, required=True, + help="Version marker for the database to keep track of changes", + ) + args = parser.parse_args() + + # Ask it to send the response as JSON + # If we don't set the "Accept" header this way, it will respond with HTML instead of JSON + json_headers = { + "Accept": "application/vnd.pypi.simple.v1+json", + } + with requests.get(PYPI_INDEX_URL, headers=json_headers) as web_request: + catalog_info = web_request.json() + package_list = { + entry["name"]: entry["_last-serial"] + for entry in catalog_info["projects"] + } + + pypi_db = PyPIDatabase.create_database(args.output, exist_ok=True) + with pypi_db.session() as session: + # Remove any outdated packages + with session.begin(): + to_update = more_itertools.peekable(( + package + for package in session.exec(select(Package)) + if package_list.get(package.package_name, package.last_serial) != package.last_serial + )) + progress_iter = tqdm( + to_update, + desc="Removing outdated packages", + colour="red", unit="Package", + disable=not to_update, + ) + session.bulk_delete(progress_iter) + + # noinspection PyTypeChecker, Pydantic + saved_packages: set[str] = set(session.exec(select(Package.package_name))) + to_update = set(package_list.keys()) - saved_packages + + # Get new packages and add to the database + TRANSACTION_SIZE = 250 + with BoundedThreadPoolExecutor() as pool: + worker_tasks = ( + functools.partial(parse_package, name) + for name in to_update + ) + futures = pool.bounded_run(worker_tasks) + + progress_iter = tqdm( + futures, + total=len(to_update), + desc="Scraping Packages", colour="blue", + unit="Package", + position=None, leave=None, + disable=not to_update, + ) + for chunk in yet_more_itertools.chunked_iter(progress_iter, TRANSACTION_SIZE): + with session.begin(): + packages = (pkg for future in chunk if (pkg := future.result())) + session.add_all(packages) + + # Reset the metadata if it already exists + # Set version + with session.begin(): + session.exec(delete(Metadata)) + session.add(Metadata( + version=args.version, + format="PyPI", + timestamp=int(datetime.now().timestamp()), + )) + + +if __name__ == "__main__": + main() diff --git a/dataset-generation/Create_Linux_DB.py b/dataset-generation/deprecated/Create_Linux_DB.py similarity index 100% rename from dataset-generation/Create_Linux_DB.py rename to dataset-generation/deprecated/Create_Linux_DB.py diff --git a/dataset-generation/Create_Maven_DB.py b/dataset-generation/deprecated/Create_Maven_DB.py similarity index 100% rename from dataset-generation/Create_Maven_DB.py rename to dataset-generation/deprecated/Create_Maven_DB.py diff --git a/dataset-generation/Create_PyPI_DB.py b/dataset-generation/deprecated/Create_PyPI_DB.py similarity index 100% rename from dataset-generation/Create_PyPI_DB.py rename to dataset-generation/deprecated/Create_PyPI_DB.py diff --git a/dataset-generation/generate_rocksdb.rs b/dataset-generation/deprecated/generate_rocksdb.rs similarity index 100% rename from dataset-generation/generate_rocksdb.rs rename to dataset-generation/deprecated/generate_rocksdb.rs diff --git a/dataset-generation/requirements.txt b/dataset-generation/requirements.txt new file mode 100644 index 0000000..e4b9960 --- /dev/null +++ b/dataset-generation/requirements.txt @@ -0,0 +1,24 @@ +#Requires python version 3.10+ + +requests +httpx +aiohttp +beautifulsoup4 +tqdm +sqlmodel +more-itertools +methodtools +natsort +python-debian +angr 
+pydemumble #Unvaiable for versions older than 3.10 +typing-extensions +backports.strenum; python_version < "3.11" + +#Python-magic-bin should include libmagic for windows, but is not available for linux +#Mac/Linux need to seperately install libmagic using apt, brew, etc. +python-magic-bin; platform_system == "Windows" +python-magic; platform_system == "Linux" or platform_system == "Darwin" + +dapper-python +dapper-python[dataset-generation] \ No newline at end of file diff --git a/python/dapper_python/databases_v2/database.py b/python/dapper_python/databases_v2/database.py new file mode 100644 index 0000000..5f22a3a --- /dev/null +++ b/python/dapper_python/databases_v2/database.py @@ -0,0 +1,279 @@ +from __future__ import annotations + +import re +import sqlite3 +import warnings +import more_itertools + +from pathlib import Path, PurePath +from enum import Flag, auto +from abc import ABC +from functools import cached_property +from contextlib import suppress + +from sqlmodel import SQLModel, Field as SQLField, Session as BaseSession +from sqlmodel import create_engine, text, delete, bindparam +from sqlalchemy import Engine, Connection, PoolProxiedConnection +from sqlalchemy import event +from sqlalchemy.types import TypeDecorator, String +from sqlalchemy.inspection import inspect + +from collections.abc import Iterable, Iterator, Generator +from typing import ClassVar, TypeVar, Type, Generic, Any +from typing import Union, Optional +from typing_extensions import Self + +from dapper_python.utils import yet_more_itertools + +ModelType = TypeVar('ModelType', bound=Type[SQLModel]) +PathType = TypeVar("PathType", bound=PurePath) + + +class RelationshipWarning(UserWarning): + ... + + +class Session(BaseSession): + """Subclass of SQLModel's Session which provides a convenience function for fast bulk insertion""" + + class AutoCommitFlags(Flag): + NONE = auto() + FLUSH_BATCH = auto() + COMMIT_BATCH = auto() + FLUSH_END = auto() + COMMIT_END = auto() + + @cached_property + def _max_params(self) -> int: + """Determine the maximum number of parameters for a prepared statement based on the sql engine + dialect + + Used to calculate maximum batching size for certain operations + """ + engine = self.get_bind() + match engine.dialect.name: + case "sqlite": + # Need to get raw connection/cursor since using session or connection starts a transaction + cursor = engine.raw_connection().cursor() + cursor.execute("SELECT sqlite_version()") + version, *_ = cursor.fetchmany()[0] + version = tuple(int(x) for x in re.split(r"[._\-]", version) if x.isdigit()) + + # SQLite versions older than 3.32.0 have a maximum parameter limit of 999 + # Whereas newer versions have a maximum limit of 32766 + # See for further detail: https://www.sqlite.org/limits.html#max_variable_number + if version < (3, 32, 0): + return 999 + else: + return 32_766 + + case _: + # Seems doable for most SQL backends. 
Revisit if this ever encounters problems + return 32_766 + + # From testing, there doesn't seem to be much difference between flushing each batch vs doing it all at the end + # So we might was well periodically flush it while batching + def bulk_insert(self, items: Iterable[SQLModel], *, + batch_size: int = 50_000, + auto_commit: AutoCommitFlags = AutoCommitFlags.FLUSH_BATCH, + ) -> None: + """Convenience function for faster insertion of bulk data + + IMPORTANT: Will only insert data for the model table itself, but not any relationships/linked tables + So only usable if the data is stored in a single table + + Takes an iterable of SQLModel objects and inserts them in batches of size batch_size + Faster than inserting single objects at a time individually + However, comes with the caveat that all the models in the iterable must be of the same type + """ + first, items = more_itertools.spy(items, n=1) + if not first: + return + model_type = type(first[0]) + + with suppress(AttributeError): + if model_type.__mapper__.relationships: + warnings.warn( + f"Class {model_type} has relationships: bulk_insert will not insert them", + category=RelationshipWarning, + stacklevel=2, + ) + + items = yet_more_itertools.enforce_single_type(items) + for batch in yet_more_itertools.chunked_iter(items, batch_size): + mappings = (x.model_dump() for x in batch) + # noinspection PyTypeChecker + self.bulk_insert_mappings(model_type, mappings) + + if self.AutoCommitFlags.FLUSH_BATCH in auto_commit: + self.flush() + if self.AutoCommitFlags.COMMIT_BATCH in auto_commit: + self.commit() + + if self.AutoCommitFlags.FLUSH_END in auto_commit: + self.flush() + if self.AutoCommitFlags.COMMIT_END in auto_commit: + self.commit() + + def bulk_delete(self, items: Iterable[SQLModel], *, + batch_size: int = 1000, + auto_commit: AutoCommitFlags = AutoCommitFlags.FLUSH_BATCH, + ) -> None: + """Convenience function for faster removal of bulk data + + Takes an iterable of SQLModel objects and removes them in batches of size batch_size + Faster than removing single objects at a time individually. 
+ However, comes with the caveat that all the models in the iterable must be of the same type + """ + batch_size = min(batch_size, self._max_params) + + first, items = more_itertools.spy(items, n=1) + if not first: + return + model_type = type(first[0]) + + # SQLModel provides bulk_insert_mappings for adding items, but there's no equivalent for bulk-removing items + # So we'll need a workaround to bulk-remove + primary_key = inspect(model_type).primary_key + if len(primary_key) != 1: + raise ValueError(f"Only supports bulk removal for non-compound primary keys {primary_key}") + primary_key = primary_key[0] + primary_key_name = primary_key.name + + items = yet_more_itertools.enforce_single_type(items) + for batch in yet_more_itertools.chunked_iter(items, batch_size): + values_to_remove = [getattr(obj, primary_key_name) for obj in batch] + stmt = delete(model_type).where(primary_key.in_(bindparam("pks", expanding=True))) + self.exec(stmt, params={"pks": values_to_remove}) + + if self.AutoCommitFlags.FLUSH_BATCH in auto_commit: + self.flush() + if self.AutoCommitFlags.COMMIT_BATCH in auto_commit: + self.commit() + + if self.AutoCommitFlags.FLUSH_END in auto_commit: + self.flush() + if self.AutoCommitFlags.COMMIT_END in auto_commit: + self.commit() + + +class BaseDatabase(ABC): + __registered_models: ClassVar[list[Type[SQLModel]]] = [] + + @classmethod + def register_model(cls, _model: ModelType) -> ModelType: + """Registers an SQLModel class to be used with the database of the subclass""" + if _model not in cls.__registered_models: + cls.__registered_models.append(_model) + return _model + + def __init_subclass__(cls): + """Each subclass needs its own registered models, ensure the list is separate""" + super().__init_subclass__() + cls.__registered_models = [] + + def __init__(self, engine: Engine) -> None: + self._engine = engine + + # Initialize the database + for model in self.__registered_models: + model.metadata.create_all(self._engine) + + # Any models registered to the base class will be set for all derived classes + for model in BaseDatabase.__registered_models: + model.metadata.create_all(self._engine) + + @classmethod + def create_database(cls, db: Union[Path, str], *, exist_ok: bool = False) -> Self: + """Create a new database file at the provided path and connects to it + + If the file already exists, it will raise a FileExistsError unless exist_ok is True + """ + if not isinstance(db, Path): + db = Path(db) + if db.exists() and not exist_ok: + raise FileExistsError(f"Database file already exists at {db}") + + db_uri = f"sqlite:///{db.absolute().as_posix()}" + engine = create_engine(db_uri, echo=False) + event.listen(engine, "connect", cls._sqlite_pragma_on_connect) + return cls(engine) + + @classmethod + def open_database(cls, db: Union[Path, str]) -> Self: + """Connect to an existing database file at the provided path + + If the file does not exist, it will raise a FileNotFoundError + """ + if not isinstance(db, Path): + db = Path(db) + if not db.exists(): + raise FileNotFoundError(f"No database file exists at {db}") + + db_uri = f"sqlite:///{db.absolute().as_posix()}" + engine = create_engine(db_uri, echo=False) + event.listen(engine, "connect", cls._sqlite_pragma_on_connect) + return cls(engine) + + @staticmethod + def _sqlite_pragma_on_connect(dbapi_conn: sqlite3.Connection, conn_record: Any) -> None: + """Enables certain features of SQL when connecting""" + cursor = dbapi_conn.cursor() + # We need to enable foreign_keys to ensure cascading deletes work properly + 
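        # Illustrative note (assumption, not enabled here): bulk loads can often be sped up
        # with additional pragmas such as "PRAGMA journal_mode=WAL" or
        # "PRAGMA synchronous=NORMAL"; only foreign_keys is required for the cascading
        # deletes this schema relies on.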
cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + def session(self) -> Session: + return Session(self._engine) + + def connection(self) -> Connection: + return self._engine.connect() + + def raw_connection(self) -> PoolProxiedConnection: + return self._engine.raw_connection() + + +@BaseDatabase.register_model +class Metadata(SQLModel, table=True): + """Should only have a single row to store metadata about the database""" + __tablename__ = "dataset_version" + + version: int = SQLField(primary_key=True) + format: str + timestamp: int + + +class SQLPath(TypeDecorator, Generic[PathType]): + """Mapper to allow storing and retrieving path objects in SQL databases via ORM + + Can provide type (i.e PurePath, PurePosixPath, etc.) to control how paths are constructed when retrieved + Paths will always be stored as posix paths in the database + + Sample Usage: + from sqlmodel import Column + Class MyModel(SQLModel, table=True): + ... + path: PurePosixPath = SQLField(sa_column=Column(SQLPath(PurePosixPath))) + """ + impl = String() + + def __init__(self, path_cls: type[PathType] = PurePath, *args, **kwargs): + super().__init__(*args, **kwargs) + + if not isinstance(path_cls, type) or not issubclass(path_cls, PurePath): + raise TypeError("path_cls must be a subclass of PurePath") + self._path_cls = path_cls + + def process_bind_param(self, value, dialect) -> Optional[str]: + if value is None: + return None + elif isinstance(value, str): + value = self._path_cls(value) + + if not isinstance(value, PurePath): + raise TypeError(f"Expected PurePath or subclass, got {type(value)}") + return value.as_posix() + + def process_result_value(self, value, dialect) -> Optional[PathType]: + if value is None: + return None + return self._path_cls(value) diff --git a/python/dapper_python/databases_v2/linux_db.py b/python/dapper_python/databases_v2/linux_db.py new file mode 100644 index 0000000..5fd63c8 --- /dev/null +++ b/python/dapper_python/databases_v2/linux_db.py @@ -0,0 +1,64 @@ +# Using __future__ annotations breaks SQLModel ORM relationships, don't use it here +# See https://github.com/fastapi/sqlmodel/discussions/900 for issue discussion + +from pathlib import PurePosixPath +from sqlmodel import SQLModel, Field as SQLField, Column +from sqlalchemy import Engine, text + +from typing import Optional + +from dapper_python.databases_v2.database import BaseDatabase +from dapper_python.databases_v2.database import SQLPath +from dapper_python.normalize import normalize_file_name + + +# This needs to be placed before the models in order for the registration decorator to work +class LinuxDatabase(BaseDatabase): + def __init__(self, engine: Engine) -> None: + super().__init__(engine) + + # Need to create views manually since SQLModel does not have native support for views + # TODO: See if there's some better way to create this without writing raw SQL + with self._engine.connect() as conn: + with conn.begin(): + create_view_cmd = """ + CREATE VIEW + IF NOT EXISTS v_package_files + AS + SELECT file_name, normalized_file_name, file_path, package_files.package_name AS package_name, full_package_name, package_sources.package_name AS source_package_name + FROM package_files + LEFT OUTER JOIN package_sources + ON package_files.package_name = package_sources.bin_package + """ + conn.execute(text(create_view_cmd)) + + +# Database Tables +@LinuxDatabase.register_model +class PackageFile(SQLModel, table=True): + __tablename__ = "package_files" + + id: Optional[int] = SQLField(default=None, nullable=False, 
primary_key=True) + file_name: str = SQLField(default=None, index=True) + normalized_file_name: str = SQLField(index=True, default=None) + file_path: PurePosixPath = SQLField(sa_column=Column(SQLPath(PurePosixPath))) + package_name: str + full_package_name: str + + # Normalized file name and file name automatically constructed from file_path if not provided + def model_post_init(self, __context) -> None: + # Automatically get and normalize the filename + self.file_path = PurePosixPath(self.file_path) + if self.file_name is None: + self.file_name = self.file_path.name + if self.normalized_file_name is None: + self.normalized_file_name = str(normalize_file_name(self.file_name)) + + +@LinuxDatabase.register_model +class PackageSource(SQLModel, table=True): + __tablename__ = "package_sources" + + id: Optional[int] = SQLField(default=None, nullable=False, primary_key=True) + package_name: str + bin_package: str = SQLField(index=True) diff --git a/python/dapper_python/databases_v2/maven_db.py b/python/dapper_python/databases_v2/maven_db.py new file mode 100644 index 0000000..2b0fc95 --- /dev/null +++ b/python/dapper_python/databases_v2/maven_db.py @@ -0,0 +1,56 @@ +# Using __future__ annotations breaks SQLModel ORM relationships, don't use it here +# See https://github.com/fastapi/sqlmodel/discussions/900 for issue discussion + +from sqlalchemy import Engine, text +from sqlmodel import SQLModel, Field as SQLField, Relationship + +from typing import Optional + +from dapper_python.databases_v2.database import BaseDatabase + + +class MavenDatabase(BaseDatabase): + def __init__(self, engine: Engine) -> None: + super().__init__(engine) + + # Need to create views manually since SQLModel does not have native support for views + # TODO: See if there's some better way to create this without writing raw SQL + with self._engine.connect() as conn: + with conn.begin(): + # User-facing view for files which hides the backend tracking logic + create_view_cmd = """ + CREATE VIEW + IF NOT EXISTS v_package_files + AS + SELECT package_name, file_name + FROM packages + JOIN package_files + ON packages.id = package_files.package_id + """ + conn.execute(text(create_view_cmd)) + + +# Database Tables +@MavenDatabase.register_model +class Package(SQLModel, table=True): + __tablename__ = "packages" + + id: Optional[int] = SQLField(default=None, nullable=False, primary_key=True) + package_name: str = SQLField(index=True) + group_id: str + timestamp: int + + # Relationships + files: list["PackageFile"] = Relationship(back_populates="package") + + +@MavenDatabase.register_model +class PackageFile(SQLModel, table=True): + __tablename__ = "package_files" + + id: Optional[int] = SQLField(default=None, nullable=False, primary_key=True) + package_id: Optional[int] = SQLField(default=None, nullable=False, foreign_key="packages.id", ondelete="CASCADE", index=True) + file_name: str = SQLField(index=True) + + # Relationships + package: "Package" = Relationship(back_populates="files") diff --git a/python/dapper_python/databases_v2/mingw_db.py b/python/dapper_python/databases_v2/mingw_db.py new file mode 100644 index 0000000..5c80314 --- /dev/null +++ b/python/dapper_python/databases_v2/mingw_db.py @@ -0,0 +1,151 @@ +# Using __future__ annotations breaks SQLModel ORM relationships, don't use it here +# See https://github.com/fastapi/sqlmodel/discussions/900 for issue discussion + +from pathlib import PurePosixPath +from pydantic import ConfigDict +from sqlmodel import SQLModel, Field as SQLField, Relationship, Column + +from typing 
import Optional + +from dapper_python.databases_v2.database import BaseDatabase +from dapper_python.databases_v2.database import SQLPath +from dapper_python.normalize import normalize_file_name + + +# This needs to be placed before the models in order for the registration decorator to work +class MinGWDatabase(BaseDatabase): + ... + + +# Database Tables +@MinGWDatabase.register_model +class Package(SQLModel, table=True): + model_config = ConfigDict(extra="allow") + __tablename__ = "packages" + + id: Optional[int] = SQLField(default=None, nullable=False, primary_key=True) + package_name: str = SQLField(index=True) + package_version: str + + # Relationships + source_files: list["SourceFile"] = Relationship(back_populates="package") + package_files: list["PackageFile"] = Relationship(back_populates="package") + + +# TODO: Looking at the number of times this pattern for XYZ_File is used, could be useful to make a common base class +# But there's also benefits to keeping them separate, changing one database doesn't impact another +# E.G. If we change the MinGW database, we aren't forced to change the Python or Linux database due to match +@MinGWDatabase.register_model +class PackageFile(SQLModel, table=True): + """File in the downloaded package (ie. what would actually appear on the system once installed) + + Not to be confused with source files we build the symbol database from + """ + model_config = ConfigDict(extra="allow") + __tablename__ = "package_files" + + id: Optional[int] = SQLField(default=None, nullable=False, primary_key=True) + package_id: Optional[int] = SQLField(default=None, nullable=False, foreign_key="packages.id", ondelete="CASCADE", index=True) + + file_name: str = SQLField(default=None) + normalized_file_name: str = SQLField(default=None) + file_path: PurePosixPath = SQLField(sa_column=Column(SQLPath(PurePosixPath))) + mime_type: str + magic_string: str + + # Relationships + package: "Package" = Relationship(back_populates="package_files") + + # Normalized file name and file name automatically constructed from file_path if not provided + def model_post_init(self, __context) -> None: + # Automatically get and normalize the filename + self.file_path = PurePosixPath(self.file_path) + if self.file_name is None: + self.file_name = self.file_path.name + if self.normalized_file_name is None: + self.normalized_file_name = str(normalize_file_name(self.file_name)) + + +@MinGWDatabase.register_model +class SourceFile(SQLModel, table=True): + """File in the source code of a package (ie. 
in the source tarball) + + Not to be confused with files that are part of the package itself and installed on a system when in use + """ + model_config = ConfigDict(extra="allow") + __tablename__ = "source_files" + + id: Optional[int] = SQLField(default=None, nullable=False, primary_key=True) + package_id: Optional[int] = SQLField(default=None, nullable=False, foreign_key="packages.id", ondelete="CASCADE", index=True) + file_name: str = SQLField(default=None) + normalized_file_name: str = SQLField(default=None) + file_path: PurePosixPath = SQLField(sa_column=Column(SQLPath(PurePosixPath))) + + # Relationships + package: "Package" = Relationship(back_populates="source_files") + functions: list["FunctionSymbol"] = Relationship(back_populates="file") + + # Normalized file name and file name automatically constructed from file_path if not provided + def model_post_init(self, __context) -> None: + # Automatically get and normalize the filename + self.file_path = PurePosixPath(self.file_path) + if self.file_name is None: + self.file_name = self.file_path.name + if self.normalized_file_name is None: + self.normalized_file_name = str(normalize_file_name(self.file_name)) + + +# Not currently implemented/used, but may want to add in the future +# Are not registered (and therefore no table created) with the database +class ClassSymbol(SQLModel, table=False): + __tablename__ = "class_symbols" + + id: Optional[int] = SQLField(default=None, nullable=False, primary_key=True) + file_id: Optional[int] = SQLField(default=None, nullable=False, foreign_key="source_files.id", ondelete="CASCADE", index=True) + + +class StructSymbol(SQLModel, table=False): + __tablename__ = "struct_symbols" + + id: Optional[int] = SQLField(default=None, nullable=False, primary_key=True) + file_id: Optional[int] = SQLField(default=None, nullable=False, foreign_key="source_files.id", ondelete="CASCADE", index=True) + + +@MinGWDatabase.register_model +class FunctionSymbol(SQLModel, table=True): + __tablename__ = "function_symbols" + + id: Optional[int] = SQLField(default=None, nullable=False, primary_key=True) + file_id: Optional[int] = SQLField(default=None, nullable=False, foreign_key="source_files.id", ondelete="CASCADE", index=True) + + return_type: str + symbol_name: str = SQLField(index=True) + qualified_symbol_name: str = SQLField(index=True) + params: str + full_signature: str = SQLField(index=True) + source_text: str + + in_binary: Optional[bool] = SQLField(default=None) + + # Relationships + file: "SourceFile" = Relationship(back_populates="functions") + + +# These are used for analysis, but not currently saved to the database, instead being dumped to JSON +# They are set up so that they could be added to the database, but are currently not for space reasons and lack of uses +# Are not registered (and therefore no table created) with the database +class PreprocessDefine(SQLModel, table=False): + __tablename__ = "preproc_defs" + + id: Optional[int] = SQLField(default=None, nullable=False, primary_key=True) + file_id: Optional[int] = SQLField(default=None, nullable=False, foreign_key="source_files.id", ondelete="CASCADE", index=True) + name: str + value: str + + +class StringLiteral(SQLModel, table=False): + __tablename__ = "string_literals" + + id: Optional[int] = SQLField(default=None, nullable=False, primary_key=True) + file_id: Optional[int] = SQLField(default=None, nullable=False, foreign_key="source_files.id", ondelete="CASCADE", index=True) + value: str diff --git a/python/dapper_python/databases_v2/python_db.py 
b/python/dapper_python/databases_v2/python_db.py new file mode 100644 index 0000000..800b1af --- /dev/null +++ b/python/dapper_python/databases_v2/python_db.py @@ -0,0 +1,96 @@ +# Using __future__ annotations breaks SQLModel ORM relationships, don't use it here +# See https://github.com/fastapi/sqlmodel/discussions/900 for issue discussion + +from pathlib import PurePosixPath +from sqlalchemy import Engine, text +from sqlmodel import SQLModel, Field as SQLField, Relationship, Column + +from typing import Optional + +from dapper_python.databases_v2.database import BaseDatabase +from dapper_python.databases_v2.database import SQLPath +from dapper_python.normalize import normalize_file_name + + +class PyPIDatabase(BaseDatabase): + def __init__(self, engine: Engine) -> None: + super().__init__(engine) + + # Need to create views manually since SQLModel does not have native support for views + # TODO: See if there's some better way to create this without writing raw SQL + with self._engine.connect() as conn: + with conn.begin(): + # User-facing view for imports which hides the backend tracking logic + create_view_cmd = """ + CREATE VIEW + IF NOT EXISTS v_package_imports + AS + SELECT package_name, import_as + FROM packages + JOIN package_imports + ON packages.id = package_imports.package_id + """ + conn.execute(text(create_view_cmd)) + + # User-facing view for files which hides the backend tracking logic + create_view_cmd = """ + CREATE VIEW + IF NOT EXISTS v_package_files + AS + SELECT package_name, normalized_file_name, file_name, file_path, mime_type, magic_string + FROM packages + JOIN package_files + ON packages.id = package_files.package_id + """ + conn.execute(text(create_view_cmd)) + + +# Database Tables +@PyPIDatabase.register_model +class Package(SQLModel, table=True): + __tablename__ = "packages" + + id: Optional[int] = SQLField(default=None, nullable=False, primary_key=True) + package_name: str + last_serial: int + + # Relationships + imports: list["PackageImport"] = Relationship(back_populates="package") + files: list["PackageFile"] = Relationship(back_populates="package") + + +@PyPIDatabase.register_model +class PackageImport(SQLModel, table=True): + __tablename__ = "package_imports" + + id: Optional[int] = SQLField(default=None, nullable=False, primary_key=True) + package_id: Optional[int] = SQLField(default=None, nullable=False, foreign_key="packages.id", index=True, ondelete="CASCADE") + import_as: str = SQLField(index=True) + + # Relationships + package: "Package" = Relationship(back_populates="imports") + + +@PyPIDatabase.register_model +class PackageFile(SQLModel, table=True): + __tablename__ = "package_files" + + id: Optional[int] = SQLField(default=None, nullable=False, primary_key=True) + package_id: Optional[int] = SQLField(default=None, nullable=False, foreign_key="packages.id", index=True, ondelete="CASCADE") + file_name: str = SQLField(default=None) + normalized_file_name: str = SQLField(default=None) + file_path: PurePosixPath = SQLField(sa_column=Column(SQLPath(PurePosixPath))) + mime_type: str + magic_string: str + + # Relationships + package: "Package" = Relationship(back_populates="files") + + # Normalized file name and file name automatically constructed from file_path if not provided + def model_post_init(self, __context) -> None: + # Automatically get and normalize the filename + self.file_path = PurePosixPath(self.file_path) + if self.file_name is None: + self.file_name = self.file_path.name + if self.normalized_file_name is None: + self.normalized_file_name = 
str(normalize_file_name(self.file_name))
diff --git a/python/dapper_python/dataset_generation/README b/python/dapper_python/dataset_generation/README
new file mode 100644
index 0000000..0e06a32
--- /dev/null
+++ b/python/dapper_python/dataset_generation/README
@@ -0,0 +1,2 @@
+The "dataset-generation" sub-package is intended to be an optional add-on used primarily when generating datasets
+It can be used by installing the "dataset-generation" extra via pip install dapper-python[dataset-generation]
\ No newline at end of file
diff --git a/python/dapper_python/dataset_generation/datatypes/exceptions.py b/python/dapper_python/dataset_generation/datatypes/exceptions.py
new file mode 100644
index 0000000..e3c7209
--- /dev/null
+++ b/python/dapper_python/dataset_generation/datatypes/exceptions.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+
+class ParseError(Exception):
+    """Exception raised when an error occurs during parsing and parsing cannot continue/complete"""
+
+    def __init__(self, *args, text: str) -> None:
+        """
+        :param text: The text that was being parsed when the error occurred
+        """
+        super().__init__(*args)
+        self.text = text
+
+    def __str__(self) -> str:
+        base = super().__str__()
+        addon = f"While parsing: \"{self.text}\""
+        return base + " " + addon
diff --git a/python/dapper_python/dataset_generation/parsing/cpp.py b/python/dapper_python/dataset_generation/parsing/cpp.py
new file mode 100644
index 0000000..2cbb9bd
--- /dev/null
+++ b/python/dapper_python/dataset_generation/parsing/cpp.py
@@ -0,0 +1,355 @@
+from __future__ import annotations
+
+import re
+import itertools
+
+import tree_sitter_cpp as ts_cpp
+
+from dataclasses import dataclass, field
+from contextlib import suppress
+from enum import Enum, auto
+
+from tree_sitter import Language, Parser, Query, QueryCursor
+from tree_sitter import Tree, Node
+
+from collections.abc import Generator
+from typing import ClassVar, Literal
+from typing import Union, Optional
+from typing_extensions import Self
+
+from dapper_python.dataset_generation.utils.ast import ancestors
+from dapper_python.dataset_generation.datatypes.exceptions import ParseError
+
+
+# In the database-generation scripts, we try to use the database SQL models when possible
+# These classes are predominantly the same as those of the databases,
+# but are kept separate as they don't belong to only a single database
+# TODO: Any better way to avoid duplication?
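For orientation, the snippet below is a minimal usage sketch of the parser classes added in this file (CPPTreeParser and its result dataclasses). It is illustrative only; the expected values in the comments are approximate rather than verified output, and it assumes the dataset-generation extra (which provides tree-sitter and tree-sitter-cpp) is installed.

# Illustrative usage sketch of CPPTreeParser; not part of the committed module
from dapper_python.dataset_generation.parsing.cpp import CPPTreeParser

source = b"""
#define VERSION "1.2.3"
namespace math {
int add(const int& a, const int& b) { return a + b; }
}
"""

parser = CPPTreeParser.from_source(source)

for fn in parser.parse_functions():
    # Expected to be roughly: int math::add(const int&, const int&)
    print(fn.full_signature)

for define in parser.parse_preproc_defs():
    # Expected to be roughly: VERSION "1.2.3"
    print(define.name, define.value)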
+ + +@dataclass(frozen=True) +class FunctionSymbol: + return_type: str + symbol_name: str = field(hash=False) + qualified_symbol_name: str + param_list: list[str] + modifiers: list[str] = field(default_factory=list) + + source_text: Optional[str] = field(default=None, hash=False, compare=False) + + @property + def params(self) -> str: + return ", ".join(self.param_list) + + @property + def full_signature(self) -> str: + if self.modifiers: + return f"{self.return_type} {self.qualified_symbol_name}({self.params}) {' '.join(self.modifiers)}" + else: + return f"{self.return_type} {self.qualified_symbol_name}({self.params})" + + +@dataclass(frozen=True) +class PreprocessDefine: + name: str + value: str + + +@dataclass(frozen=True) +class StringLiteral: + value: str + + +@dataclass +class CPPTreeParser: + """Parses tree-sitter AST for C/C++ source code""" + tree: Tree + + @classmethod + def from_source(cls, contents: Union[bytes, bytearray], *, encoding: Literal["utf8", "utf16", "utf16le", "utf16be"] = "utf8") -> Self: + ts_parser = Parser(cls.__CPP_LANG) + tree = ts_parser.parse(contents, encoding=encoding) + return cls(tree) + + def parse_functions(self) -> Generator[FunctionSymbol, None, None]: + """Extracts all function definitions from the source code""" + cursor = QueryCursor(self._FUNCTION_QUERY) + + for pattern, elements in cursor.matches(self.tree.root_node): + fn_definition_node = elements["function_definition"][0] + + # If there are errors, we're unlikely to get a satisfactory result, so skip + error_cursor = QueryCursor(self._ERROR_QUERY) + if error_cursor.matches(fn_definition_node): + continue + + # Tree sitter picks out too many edge cases to filter on preemptively + # Try to parse the "happy path" and skip failures as it was likely misidentified as a function or malformed + with suppress(ParseError, UnicodeDecodeError): + function_parser = CPPFunctionParser(fn_definition_node) + function = function_parser.parse_function() + yield function + + def parse_preproc_defs(self) -> Generator[PreprocessDefine, None, None]: + """Extracts all preprocessor "define" macros from the source code + + This only includes ones that have a value associated with them, defines that do not are ignored + #define PI 3.14159 <- Will be included + #define INCLUDE_GUARD_H <- Will be ignored + """ + cursor = QueryCursor(self._PREPROC_QUERY) + + for pattern, elements in cursor.matches(self.tree.root_node): + with suppress(UnicodeDecodeError): + yield PreprocessDefine( + name=elements["name"][0].text.decode().strip(), + value=elements["value"][0].text.decode().strip(), + ) + + def parse_string_literals(self) -> Generator[StringLiteral, None, None]: + """Extracts all lines from source code that contain string literal(s) + + Extracts an entire expression if the string literal is part of a declaration or expression statement + Individual strings aren't that useful: ["version", "built"] + But as part of an expression could be: cout << "version" << ns.VERSION_NUMBER << "built" << ns.DATE; + """ + cursor = QueryCursor(self._STRING_LITERAL_QUERY) + + for pattern, elements in cursor.matches(self.tree.root_node): + # Exclude any strings from preproc includes + if any(node.type == "preproc_include" for node in ancestors(elements["string"][0])): + continue + + with suppress(StopIteration, UnicodeDecodeError): + parent_node = next(( + x for x in ancestors(elements["string"][0]) + if x.type in ("declaration", "expression_statement") + )) + + yield StringLiteral( + value=parent_node.text.decode().strip(), + ) + + 
__CPP_LANG: ClassVar = Language(ts_cpp.language()) + + _ERROR_QUERY: ClassVar[Query] = Query( + __CPP_LANG, + "[(ERROR) (MISSING)] @error", + ) + + _FUNCTION_QUERY: ClassVar[Query] = Query( + __CPP_LANG, + """ + ( + function_definition + (type_qualifier)* @type_qualifier + type: (_) @type + declarator: (_) @declarator + ) @function_definition + """, + ) + + _PREPROC_QUERY: ClassVar[Query] = Query( + __CPP_LANG, + """ + ( + preproc_def + name: (_) @name + value: (_) @value + ) + """, + ) + + _STRING_LITERAL_QUERY: ClassVar[Query] = Query( + __CPP_LANG, + """ + (string_literal) @string + """, + ) + + +@dataclass +class CPPFunctionParser: + """Specific sub-parser for tree-sitter AST for C/C++ function definitions""" + fn_node: Node + + class _TypeSpec(Enum): + FUNCTION = auto() + PARAMETER = auto() + AUTO = auto() + + def parse_function(self) -> FunctionSymbol: + """Parses the function definition node and returns a FunctionSymbol""" + try: + # Function return type + return_type = self._parse_type(self.fn_node) + return_type = self.normalize_string(return_type) + + # Function name/qualified name + fn_declarator_cursor = QueryCursor(self._FUNCTION_DECLARATOR_QUERY) + fn_declarator_node = fn_declarator_cursor.matches(self.fn_node)[0][1]["function_declarator"][0] + qualified_name = fn_declarator_node.child_by_field_name("declarator").text.decode() + + # Check if the function is inside any classes/namespaces that would further add onto the name + qualified_ancestors = reversed([ + x for x in ancestors(fn_declarator_node) + if x.type in ("namespace_definition", "class_specifier") + ]) + addon_qualifiers = [ + self.normalize_string(name_node.text.decode()) + for ancestor in qualified_ancestors + if (name_node := ancestor.child_by_field_name("name")) + ] + if addon_qualifiers: + qualified_name = f"{'::'.join(addon_qualifiers)}::{qualified_name}" + qualified_name = self.normalize_string(qualified_name) + + # Sometimes field_identifier instead of identifier when part of a class in a header file + fn_identifier_cursor = QueryCursor(self._FUNCTION_IDENTIFIER_QUERY) + fn_identifier_node = fn_identifier_cursor.matches(fn_declarator_node)[0][1]["identifier"][0] + name = fn_identifier_node.text.decode() + name = self.normalize_string(name) + + # Anything that modifies the function, such as "const" applied to a method + modifiers = [ + self.normalize_string(x.text.decode()) + for x in fn_declarator_node.children + if x.type == "type_qualifier" + ] + + # Function parameters + param_cursor = QueryCursor(self._FUNCTION_PARAMETER_QUERY) + param_nodes = param_cursor.matches(self.fn_node)[0][1]["parameter_list"][0] + param_nodes = [ + x for x in param_nodes.children + if x.type in ("parameter_declaration", "optional_parameter_declaration") + ] + parameters = [ + self.normalize_string(self._parse_type(p_node)) + for p_node in param_nodes + ] + + source_text = self._signature_text(self.fn_node).strip() + source_text = self._MULTI_WHITESPACE_REGEX.sub(" ", source_text).strip() + + return FunctionSymbol( + return_type=return_type, + symbol_name=name, + qualified_symbol_name=qualified_name, + param_list=parameters, + modifiers=modifiers, + source_text=source_text, + ) + except IndexError as e: + raise ParseError(text=self.fn_node.text.decode().strip()) from e + + @classmethod + def _parse_type(cls, node: Node) -> str: + """Parses the given node to extract the type + + Supports: + Function_definition node to extract function return type + Parameter_declaration (or optional_parameter_declaration) to extract parameter type 
+ """ + if node.type == "function_definition": + type_spec = cls._TypeSpec.FUNCTION + identifier_query = Query(cls.__CPP_LANG, "(function_declarator) @declarator") + elif node.type in ("parameter_declaration", "optional_parameter_declaration"): + type_spec = cls._TypeSpec.PARAMETER + identifier_query = Query(cls.__CPP_LANG, "(identifier) @declarator") + else: + raise TypeError(f"Unexpected node type: {node.type}") + identifier_cursor = QueryCursor(identifier_query) + + # Extract the full return type by combining any qualifiers, type, and modifiers that are part of the declarator + qualifiers = [ + x.text.decode().strip() + for x in node.children + if x.type == "type_qualifier" + ] + + base_type = node.child_by_field_name("type").text.decode().strip() + + modifiers = [] + with suppress(IndexError): + decl_node = identifier_cursor.matches(node)[0][1]["declarator"][0] + modifier_chain = reversed(list( + itertools.takewhile(lambda x: x != node, ancestors(decl_node)), + )) + + for modifier in modifier_chain: + literals = [x for x in modifier.children if x.type in ("*", "&", "&&")] + for literal in literals: + modifiers.append(literal.text.decode()) + + # Decay array type to pointer + if modifier.type == "array_declarator": + modifiers.append("*") + + if qualifier := next((x for x in modifier.children if x.type == "type_qualifier"), None): + modifiers.append(qualifier.text.decode()) + + final_type = f"{' '.join(qualifiers)} {base_type}{' '.join(modifiers)}" + final_type = cls.normalize_string(final_type) + + return final_type + + @classmethod + def normalize_string(cls, string: str) -> str: + """Normalizes the given signature to a uniform format""" + for pattern, replacement in cls._CLEANUP_REGEXES.items(): + string = pattern.sub(replacement, string) + return string.strip() + + @classmethod + def _signature_text(cls, fn_node: Node) -> str: + """Extracts the signature of the given function definition while discarding the body + + We want to be able to compare the source that we extracted the information from + But don't actually care about the body contents of the function + """ + if fn_node.type == "declaration": + return fn_node.text.decode().strip() + elif fn_node.type != "function_definition": + raise TypeError(f"Expected function_definition node, got {fn_node.type}") + + function_bytes = fn_node.text + function_start = fn_node.start_byte + + body = fn_node.child_by_field_name("body") + if body is None: + return function_bytes.decode() + + body_start = body.start_byte - function_start + body_end = body.end_byte - function_start + + # Concatenate everything excluding the body + signature_bytes = function_bytes[:body_start] + function_bytes[body_end:] + return signature_bytes.decode().strip() + + __CPP_LANG: ClassVar = Language(ts_cpp.language()) + + _FUNCTION_DECLARATOR_QUERY: ClassVar[Query] = Query( + __CPP_LANG, + "(function_declarator) @function_declarator", + ) + _FUNCTION_IDENTIFIER_QUERY: ClassVar[Query] = Query( + __CPP_LANG, + "[(identifier) (field_identifier) (operator_name)] @identifier", + ) + _FUNCTION_PARAMETER_QUERY: ClassVar[Query] = Query( + __CPP_LANG, + "parameters: (parameter_list) @parameter_list", + ) + + # Used to substitute strings matching the regex (key) with the provided string (value) + # For normalization and consistent formatting in dataset + # Processed in the order they are listed + _CLEANUP_REGEXES: ClassVar[dict[re.Pattern, str]] = { + re.compile(r"\n"): " ", # Remove newlines: "a\nb\nc" -> "a b c" + re.compile(r"\s+(?=[*&])"): "", # Ensure spacing of pointers 
and references: "int *" or "float &" -> "int*" and "float&"
+        re.compile(r",(?=\S)"): ", ",  # Ensure consistent spacing of lists: "a,b,c,d" -> "a, b, c, d"
+        re.compile(r"(?<=<)\s+|\s+(?=>)"): "",  # Ensure no empty spaces inside template brackets: "pair< int, float >" -> "pair<int, float>"
+        re.compile(r"\s+"): " ",  # Collapse multiple sequential spaces: "a  b  c" -> "a b c"
+    }
+    _MULTI_WHITESPACE_REGEX: ClassVar[re.Pattern] = re.compile(r"\s+")
diff --git a/python/dapper_python/dataset_generation/utils/archive.py b/python/dapper_python/dataset_generation/utils/archive.py
new file mode 100644
index 0000000..864aac0
--- /dev/null
+++ b/python/dapper_python/dataset_generation/utils/archive.py
@@ -0,0 +1,97 @@
+"""
+Based on the kinds of files we're scraping for this project (official package archives from Ubuntu, Debian, etc.),
+we should hopefully not encounter malicious archives, but we try to be as safe as possible anyway
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from tarfile import TarFile, TarInfo
+from zipfile import ZipFile, ZipInfo
+
+from typing import Union
+
+
+class SafeTarFile(TarFile):
+    def safe_extractall(self, path: Union[Path, str], **kwargs) -> None:
+        """Extracts all archive members to a given path
+
+        Does some additional checking to try to prevent malicious tarfile contents from being extracted,
+        such as files that use absolute paths or paths containing ".." to try and modify files outside the target directory
+
+        Intended as an improved-safety version of extractall() for compatibility with older Python versions,
+        but a better approach is to use a newer Python version which has better security built into extractall()
+        However, this library currently supports older Python versions that don't have that built in
+        """
+        if isinstance(path, str):
+            path = Path(path)
+        if not path.exists():
+            raise FileNotFoundError(f"No such directory: {path}")
+        elif not path.is_dir():
+            raise NotADirectoryError(f"Not a directory: {path}")
+        path = path.resolve()
+
+        kwargs.pop("member", None)
+        kwargs.pop("path", None)
+        for member in self.getmembers():
+            output_path = path.joinpath(member.name).resolve()
+
+            if not output_path.is_relative_to(path):
+                # Do not extract any file that would be placed outside the provided root path
+                continue
+
+            if member.issym() or member.islnk():
+                link_target = Path(member.linkname)
+                if link_target.is_absolute():
+                    # Exclude all absolute symlinks
+                    continue
+
+                link_target = output_path.parent.joinpath(link_target).resolve()
+                if not link_target.is_relative_to(path):
+                    # Exclude any symlink whose target would be outside the provided root path
+                    continue
+
+            self.extract(member, path=path, **kwargs)
+
+
+class SafeZipFile(ZipFile):
+    def safe_extractall(self, path: Union[Path, str], **kwargs) -> None:
+        """Extracts all archive members to a given path
+
+        Does some additional checking to try to prevent malicious archive contents from being extracted,
+        such as files that use absolute paths or paths containing ".."
to try and modify files outside the target directory + """ + if isinstance(path, str): + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"No such directory: {path}") + elif not path.is_dir(): + raise NotADirectoryError(f"Not a directory: {path}") + path = path.resolve() + + kwargs.pop("member", None) + kwargs.pop("path", None) + for member in self.namelist(): + output_path = path.joinpath(member).resolve() + if not output_path.is_relative_to(path): + # Do not extract any file that would be placed outside the provided root path + continue + + member_info = self.getinfo(member) + if self.is_symlink(member_info): + link_target = Path(self.open(member_info).read().decode("utf-8")) + if link_target.is_absolute(): + # Exclude all absolute symlinks + continue + + link_target = output_path.parent.joinpath(link_target).resolve() + if not link_target.is_relative_to(path): + # Exclude any symlink whose target would be outside the provided root path + continue + + self.extract(member, path=path, **kwargs) + + @staticmethod + def is_symlink(zip_info: ZipInfo) -> bool: + mode = (zip_info.external_attr >> 16) & 0xFFFF + return (mode & 0o170000) == 0o120000 diff --git a/python/dapper_python/dataset_generation/utils/ast.py b/python/dapper_python/dataset_generation/utils/ast.py new file mode 100644 index 0000000..da89c20 --- /dev/null +++ b/python/dapper_python/dataset_generation/utils/ast.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import tree_sitter + +from collections.abc import Generator + + +def ancestors(node: tree_sitter.Node) -> Generator[tree_sitter.Node, None, None]: + """Yields all ancestor nodes of the provided node + + Order from closest to furthest ancestor + parent -> grandparent -> ... -> root + """ + current = node.parent + while current: + yield current + current = current.parent + + +def descendants(node: tree_sitter.Node) -> Generator[tree_sitter.Node, None, None]: + """Yields all descendant nodes of the provided node""" + for child in node.children: + yield child + yield from descendants(child) diff --git a/python/dapper_python/dataset_generation/utils/futures.py b/python/dapper_python/dataset_generation/utils/futures.py new file mode 100644 index 0000000..163e9ca --- /dev/null +++ b/python/dapper_python/dataset_generation/utils/futures.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import math +import itertools +import concurrent.futures + +from concurrent.futures import Future, Executor +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor +from concurrent.futures import FIRST_COMPLETED + +from collections.abc import Iterable, Callable, Generator +from typing import TypeVar, Type +from typing import ClassVar +from typing import Union, Optional + +T = TypeVar("T") +U = TypeVar("U") +R = TypeVar("R") +ExceptionGroupType = Union[Type[Exception], tuple[Type[Exception], ...]] + + +class BoundedSubmissionMixin(Executor): + """Mixin class for concurrent.futures executor pools to allow bounded submission of tasks""" + + def bounded_run(self, _iter: Iterable[Callable[[], R]], *, + bound: Optional[int] = None) -> Generator[Future[R], None, None]: + """Submits tasks to the pool with a bound on the number of tasks submitted at any given time + Yields future objects as they complete + + Allows for processing a large number of tasks without upfront initialization of all futures in memory + Creating from generator can mean only a handful of futures are created at a time + + Equivalent functionality exists in python 3.14+ using 
Executor.map() with buffersize provided
+        However, this does not exist in older versions; hence this implementation
+        """
+        if bound is None:
+            # ThreadPoolExecutor and ProcessPoolExecutor both expose their worker count as _max_workers
+            pool_size = getattr(self, "_max_workers", self._FALLBACK_SUBMISSION_WORKERS)
+            bound = int(math.ceil(pool_size * self.SUBMISSION_BOUND_RATIO))
+
+        futures = set()
+        it = iter(_iter)
+
+        for _callable in itertools.islice(it, bound):
+            futures.add(self.submit(_callable))
+
+        while futures:
+            done, futures = concurrent.futures.wait(futures, return_when=FIRST_COMPLETED)
+            for _callable in itertools.islice(it, len(done)):
+                futures.add(self.submit(_callable))
+            yield from done
+
+    # Ratio of the number of tasks to submit compared to the number of workers
+    # If there are 6 workers and the ratio is 2, then 12 tasks will be in the queue at any given time
+    SUBMISSION_BOUND_RATIO: ClassVar[float] = 2
+    # Number of assumed workers if unable to determine the pool's worker count
+    _FALLBACK_SUBMISSION_WORKERS: ClassVar[int] = 4
+
+
+class BoundedThreadPoolExecutor(ThreadPoolExecutor, BoundedSubmissionMixin):
+    ...
+
+
+class BoundedProcessPoolExecutor(ProcessPoolExecutor, BoundedSubmissionMixin):
+    ...
+
+
+def result_or_default(future: Future[T], *, suppress: ExceptionGroupType, default: U) -> Union[T, U]:
+    """Gets the result from a future object, returning the default if a provided exception is raised
+    Intended for use in generator/list comprehensions since try/except blocks are not easily used in such comprehensions
+
+    If the raised exception is not in the suppress list, it will be re-raised
+    """
+    if isinstance(suppress, type) and issubclass(suppress, Exception):
+        suppress = (suppress,)
+
+    try:
+        return future.result()
+    except suppress:
+        return default
diff --git a/python/dapper_python/dataset_generation/utils/scraping.py b/python/dapper_python/dataset_generation/utils/scraping.py
new file mode 100644
index 0000000..2b5b265
--- /dev/null
+++ b/python/dapper_python/dataset_generation/utils/scraping.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+import requests
+import time
+
+from io import BytesIO
+from http import HTTPStatus
+from requests.exceptions import HTTPError, ConnectionError, Timeout
+from tqdm.auto import tqdm
+
+from collections.abc import Mapping
+from typing import Union, Optional
+
+
+def download(url: str, *,
+             request_params: Optional[Mapping] = None,
+             progress_params: Union[Mapping, bool] = False) -> tuple[BytesIO, Optional[str]]:
+    """Utility function for downloading data with progress display
+
+    :return: Tuple of (BytesIO, Content-Type); if the content type is not provided in the response header, it will be None
+    """
+    if request_params is None:
+        request_params = {}
+    request_params.pop("stream", None)
+
+    with requests.get(url, **request_params, stream=True) as web_request:
+        web_request.raise_for_status()
+        if 'content-length' in web_request.headers:
+            file_size = int(web_request.headers['content-length'])
+        else:
+            file_size = None
+
+        _progress_params = {
+            "total": file_size,
+            "desc": "Downloading file",
+            "unit": 'B',
+            "unit_divisor": 1024,
+            "unit_scale": True,
+            "position": None,
+            "leave": None,
+        }
+        if isinstance(progress_params, Mapping):
+            _progress_params.update(progress_params)
+        elif isinstance(progress_params, bool):
+            _progress_params["disable"] = not progress_params
+
+        content = BytesIO()
+        with tqdm(**_progress_params) as progress_bar:
+            for chunk in web_request.iter_content(chunk_size=8 * 1024):
+                content.write(chunk)
+                progress_bar.update(len(chunk))
+
+        content.seek(0)
+        return content, web_request.headers.get('Content-Type', None)
+
+
+def get_with_retry(url: str, *, retries: int = 5, **kwargs) -> requests.Response:
+    """Wrapper around requests.get with support for automatic retries on failure
+
+    Attempts to retrieve the web content from the specified URL
+    Will retry if the request is not successful (i.e. does not receive a successful HTTP status)
+    """
+    for attempt in range(retries + 1):
+        try:
+            response = requests.get(url, **kwargs)
+            response.raise_for_status()
+            return response
+
+        except (ConnectionError, Timeout):
+            if attempt >= retries:
+                raise
+            time.sleep(1)
+
+        except HTTPError as e:
+            if attempt >= retries:
+                raise
+
+            match e.response.status_code:
+                case HTTPStatus.NOT_FOUND:
+                    raise
+                case HTTPStatus.TOO_MANY_REQUESTS:
+                    # Retry-After arrives as a header string (seconds, or an HTTP-date); fall back to a short delay if it can't be parsed
+                    try:
+                        retry_after = float(e.response.headers.get('Retry-After', 1))
+                    except ValueError:
+                        retry_after = 1
+                    time.sleep(retry_after)
+                case _:
+                    time.sleep(1)
+    else:
+        # We should never reach this. All paths should either return or raise
+        # Only present for static type checking
+        raise RuntimeError("Unreachable: every retry path should have returned or raised")
diff --git a/python/dapper_python/utils/yet_more_itertools.py b/python/dapper_python/utils/yet_more_itertools.py
new file mode 100644
index 0000000..5964ec6
--- /dev/null
+++ b/python/dapper_python/utils/yet_more_itertools.py
@@ -0,0 +1,51 @@
+"""
+Provides some additional iterator functionality not present in itertools or more_itertools
+
+The module name is a play on Python's itertools module and the more_itertools package, which adds additional iterator functionality
+This adds even more functionality on top of those two
+"""
+from __future__ import annotations
+
+import itertools
+import more_itertools
+
+from collections.abc import Iterable, Iterator, Generator
+from typing import TypeVar
+from typing import Optional
+
+T = TypeVar("T")
+
+
+def chunked_iter(iterable: Iterable[T], chunk_size: Optional[int]) -> Generator[Iterator[T], None, None]:
+    """Splits an iterable into iterable chunks of size chunk_size
+
+    Behaves very similarly to more_itertools.chunked, but is more memory efficient by operating only on iterators
+    If run on a generator/iterator, it does not load all entries into memory the way that more_itertools.chunked does
+
+    Instead of returning a list of up to N items, it returns a generator that itself yields up to those N items
+    """
+    # While this is fairly simple, it cannot be written directly as a comprehension, since comprehensions allow "for" loops but not "while"
+    it = more_itertools.peekable(iterable)
+    while it:
+        # Let islice handle validation of the chunk_size argument
+        yield itertools.islice(it, chunk_size)
+
+
+def enforce_single_type(iterable: Iterable[T]) -> Generator[T, None, None]:
+    """Ensures all objects in an iterable are of the same type without consuming the iterable to check
+
+    Takes an iterable of objects and yields them back, behaving mostly transparently as if it were the original iterator
+    However, if it encounters an object of a different type than ones that came before it, an error is raised
+
+    Created as an inline way to ensure/enforce type matching without an operation like all(type(x) ...
for x in iterable) + Which would consume the iterable to check all types, thus leaving it unusable afterward + """ + first, iterable = more_itertools.spy(iterable, n=1) + if not first: + return + obj_type = type(first[0]) + + for obj in iterable: + if type(obj) is not obj_type: + raise TypeError(f"Got different type {type(obj)}, expected {obj_type}") + yield obj diff --git a/python/pyproject.toml b/python/pyproject.toml index 162f4b7..dfcc482 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,17 +4,22 @@ build-backend = "setuptools.build_meta" [project] name = "dapper-python" -version = "0.0.0.dev3" +version = "0.0.0.dev4" description = "A Python package for interacting with DAPper datasets" authors = [ { name = "Ryan Mast", email = "mast9@llnl.gov" } ] license = { text = "MIT License" } readme = "README.md" -requires-python = ">=3.6" +requires-python = ">=3.9" dependencies = [ "tomlkit", - "typing-extensions>=4.6; python_version < '3.10'" + "more-itertools", + "typing-extensions", + #Database + "sqlmodel", + "pydantic", + "sqlalchemy", ] classifiers = [ "Programming Language :: Python :: 3", @@ -34,12 +39,14 @@ Discussions = "https://github.com/LLNL/dapper/discussions" "Source Code" = "https://github.com/LLNL/dapper" [project.optional-dependencies] -test = ["pytest"] +test = ["pytest", "tree-sitter", "tree-sitter-cpp"] dev = ["build", "pre-commit"] +dataset-generation = ["requests", "tqdm", "tree-sitter", "tree-sitter-cpp"] [dependency-groups] -test = ["pytest"] +test = ["pytest", "tree-sitter", "tree-sitter-cpp"] dev = ["build", "pre-commit"] +dataset-generation = ["requests", "tqdm", "tree-sitter", "tree-sitter-cpp"] [tool.setuptools.packages.find] include = ["dapper_python", "dapper_python.*"] diff --git a/python/tests/test_cpp_parsing.py b/python/tests/test_cpp_parsing.py new file mode 100644 index 0000000..066f68b --- /dev/null +++ b/python/tests/test_cpp_parsing.py @@ -0,0 +1,232 @@ +import pytest + +from typing import Optional + +from dapper_python.dataset_generation.parsing.cpp import CPPTreeParser +from dapper_python.dataset_generation.parsing.cpp import FunctionSymbol + + +@pytest.mark.parametrize( + "test_input,expected,expected_signature", + [ + # Simple int return type + pytest.param( + "int simple_int(int x){}", + FunctionSymbol( + return_type="int", + symbol_name="simple_int", + qualified_symbol_name="simple_int", + param_list=["int"], + ), + "int simple_int(int)", + ), + # Pointer return type, namespaced + pytest.param( + "double** math::submodule::get_buffer(size_t n){}", + FunctionSymbol( + return_type="double**", + symbol_name="get_buffer", + qualified_symbol_name="math::submodule::get_buffer", + param_list=["size_t"], + ), + "double** math::submodule::get_buffer(size_t)", + ), + # Reference return type, class scope + pytest.param( + "std::string& StringUtil::get_ref(std::string& s){}", + FunctionSymbol( + return_type="std::string&", + symbol_name="get_ref", + qualified_symbol_name="StringUtil::get_ref", + param_list=["std::string&"], + ), + "std::string& StringUtil::get_ref(std::string&)", + ), + # Rvalue reference return type, template parameter + pytest.param( + "std::vector&& VecUtil::move_vector(std::vector&& v){}", + FunctionSymbol( + return_type="std::vector&&", + symbol_name="move_vector", + qualified_symbol_name="VecUtil::move_vector", + param_list=["std::vector&&"], + ), + "std::vector&& VecUtil::move_vector(std::vector&&)", + ), + # Const reference return type, nested namespace + pytest.param( + "const std::map>& ns1::ns2::get_map(){}", + 
FunctionSymbol( + return_type="const std::map>&", + symbol_name="get_map", + qualified_symbol_name="ns1::ns2::get_map", + param_list=[], + ), + "const std::map>& ns1::ns2::get_map()", + ), + # Function pointer return type, pointer param (SKIPPED) + pytest.param( + "void (*CallbackUtil::get_callback())(int){}", + FunctionSymbol( + return_type="void (*)(int)", + symbol_name="get_callback", + qualified_symbol_name="CallbackUtil::get_callback", + param_list=[], + ), + "void (*)(int)", + marks=pytest.mark.skip(reason="Not currently supported"), + ), + # Returning vector of pointers, pointer param + pytest.param( + "std::vector PtrVec::make_vector(int* arr[], size_t n){}", + FunctionSymbol( + return_type="std::vector", + symbol_name="make_vector", + qualified_symbol_name="PtrVec::make_vector", + param_list=["int**", "size_t"], + ), + "std::vector PtrVec::make_vector(int**, size_t)", + ), + # Template type parameter, reference param + pytest.param( + "std::vector> DataUtil::process(const std::vector>& data){}", + FunctionSymbol( + return_type="std::vector>", + symbol_name="process", + qualified_symbol_name="DataUtil::process", + param_list=["const std::vector>&"], + ), + "std::vector> DataUtil::process(const std::vector>&)", + ), + # Pointer to vector param, returning int + pytest.param( + "int VecStat::sum(const std::vector< int >* vec){}", + FunctionSymbol( + return_type="int", + symbol_name="sum", + qualified_symbol_name="VecStat::sum", + param_list=["const std::vector*"], + ), + "int VecStat::sum(const std::vector*)", + ), + # Returning const pointer, pointer param + pytest.param( + "const char* StrUtil::find_char(const char* str, char c){}", + FunctionSymbol( + return_type="const char*", + symbol_name="find_char", + qualified_symbol_name="StrUtil::find_char", + param_list=["const char*", "char"], + ), + "const char* StrUtil::find_char(const char*, char)", + ), + # Returning volatile pointer, volatile pointer param + pytest.param( + "volatile int* VolUtil::get_volatile(volatile int* p){}", + FunctionSymbol( + return_type="volatile int*", + symbol_name="get_volatile", + qualified_symbol_name="VolUtil::get_volatile", + param_list=["volatile int*"], + ), + "volatile int* VolUtil::get_volatile(volatile int*)", + ), + # Returning restrict pointer, restrict pointer param (C only) + pytest.param( + "int* restrict RestrictUtil::restrict_op(int* restrict p){}", + FunctionSymbol( + return_type="int* restrict", + symbol_name="restrict_op", + qualified_symbol_name="RestrictUtil::restrict_op", + param_list=["int* restrict"], + ), + "int* restrict RestrictUtil::restrict_op(int* restrict)", + ), + # Returning std::vector, param is std::vector + pytest.param( + "std::vector VecUtil::int_to_string(const std::vector& v){}", + FunctionSymbol( + return_type="std::vector", + symbol_name="int_to_string", + qualified_symbol_name="VecUtil::int_to_string", + param_list=["const std::vector&"], + ), + "std::vector VecUtil::int_to_string(const std::vector&)", + ), + # Returning std::pair, param is std::map + pytest.param( + "std::pair MapUtil::find_pair(const std::map& m, int key){}", + FunctionSymbol( + return_type="std::pair", + symbol_name="find_pair", + qualified_symbol_name="MapUtil::find_pair", + param_list=["const std::map&", "int"], + ), + "std::pair MapUtil::find_pair(const std::map&, int)", + ), + # Argument with no identifier + pytest.param( + "void no_ident(int, bool){}", + FunctionSymbol( + return_type="void", + symbol_name="no_ident", + qualified_symbol_name="no_ident", + param_list=["int", "bool"], + ), + 
"void no_ident(int, bool)", + ), + # Const modifier on function + pytest.param( + "const std::string& error() const throw(){}", + FunctionSymbol( + return_type="const std::string&", + symbol_name="error", + qualified_symbol_name="error", + param_list=[], + modifiers=["const"], + ), + "const std::string& error() const", + ), + # Operator [] + pytest.param( + "const int& operator [](int index){}", + FunctionSymbol( + return_type="const int&", + symbol_name="operator []", + qualified_symbol_name="operator []", + param_list=["int"], + ), + "const int& operator [](int)", + ), + # Operator -> + pytest.param( + "myobj* operator ->(){}", + FunctionSymbol( + return_type="myobj*", + symbol_name="operator ->", + qualified_symbol_name="operator ->", + param_list=[], + ), + "myobj* operator ->()", + ), + # Namespace + Class + pytest.param( + "namespace foo{class bar{int baz(){return 1;}};}", + FunctionSymbol( + return_type="int", + symbol_name="baz", + qualified_symbol_name="foo::bar::baz", + param_list=[], + ), + "int foo::bar::baz()", + ), + ], +) +def test_function_parsing(test_input: str, expected: FunctionSymbol, expected_signature: Optional[str]): + tree = CPPTreeParser.from_source(test_input.encode()) + actual = list(tree.parse_functions()) + assert len(actual) == 1 + assert actual[0] == expected + + if expected_signature is not None: + assert actual[0].full_signature == expected_signature diff --git a/python/tests/test_database_v2.py b/python/tests/test_database_v2.py new file mode 100644 index 0000000..86c10ac --- /dev/null +++ b/python/tests/test_database_v2.py @@ -0,0 +1,215 @@ +import pytest + +import re +import random +import string +import warnings +import sqlite3 + +from pathlib import Path, PurePosixPath, PureWindowsPath +from sqlmodel import SQLModel, Field as SQLField, Column, Relationship +from sqlmodel import select + +from typing import Optional + +from dapper_python.databases_v2.database import BaseDatabase +from dapper_python.databases_v2.database import SQLPath + + +class UTDatabase(BaseDatabase): + + @property + def db_path(self) -> Path: + with self.session() as session: + db_uri = str(session.get_bind().url) + db_path = db_uri.removeprefix("sqlite:///") + return Path(db_path) + + +@UTDatabase.register_model +class UTModel1(SQLModel, table=True): + __tablename__ = "test_table_1" + + id: int = SQLField(primary_key=True) + value: str + + +@UTDatabase.register_model +class UTModel2(SQLModel, table=True): + __tablename__ = "test_table_2" + + id: int = SQLField(primary_key=True) + + # Relationships + T3: "UTModel3" = Relationship(back_populates="T2") + + +@UTDatabase.register_model +class UTModel3(SQLModel, table=True): + __tablename__ = "test_table_3" + id: Optional[int] = SQLField(default=None, primary_key=True, + foreign_key="test_table_2.id", ondelete="CASCADE") + value: str + + # Relationships + T2: "UTModel2" = Relationship(back_populates="T3") + + +@UTDatabase.register_model +class UTModel4(SQLModel, table=True): + __tablename__ = "test_table_4" + + id: int = SQLField(primary_key=True) + posix_path: PurePosixPath = SQLField(sa_column=Column(SQLPath(PurePosixPath))) + windows_path: PureWindowsPath = SQLField(sa_column=Column(SQLPath(PureWindowsPath))) + + +@pytest.fixture +def database(tmp_path): + database_path = tmp_path.joinpath("test_database.db") + return UTDatabase.create_database(database_path) + + +def generate_test_data(n: int = 1000, strlen: int = 20) -> dict[int, str]: + return { + i: "".join(random.choice(string.ascii_letters) for _ in range(strlen)) + for i in 
range(n) + } + + +def test_bulk_insert(database: UTDatabase): + test_data = generate_test_data() + + with database.session() as session: + data = ( + UTModel1( + id=key, + value=value, + ) + for key, value in test_data.items() + ) + with session.begin(): + session.bulk_insert(data, batch_size=95) + + # Check that the values match what we expect by accessing the database directly + with sqlite3.connect(database.db_path) as conn: + cursor = conn.cursor() + query = """ + SELECT id, value + FROM test_table_1 + WHERE id = ? + """ + for key, value in test_data.items(): + cursor.execute(query, (int(key),)) + d_key, d_value, *_ = cursor.fetchone() + + assert d_key == key + assert d_value == value + + +def test_bulk_insert_warnings(database: UTDatabase): + test_data = generate_test_data() + + with database.session() as session: + data = ( + UTModel2( + id=key, + T2=UTModel3( + value=value, + ), + ) + for key, value in test_data.items() + ) + + # We should get a warning when using bulk_insert to add a class that has relationships + with warnings.catch_warnings(record=True) as w: + with session.begin(): + session.bulk_insert(data, batch_size=100) + + assert len(w) == 1 + expected_message = "Class {cls} has relationships: bulk_insert will not insert them" + assert str(w[0].message) == expected_message.format(cls=UTModel2) + + +def test_bulk_delete(database: UTDatabase): + test_data = generate_test_data() + + with database.session() as session: + data = ( + UTModel2( + id=key, + T3=UTModel3( + value=value, + ), + ) + for key, value in test_data.items() + ) + with session.begin(): + session.add_all(data) + + with session.begin(): + to_remove = (( + entry + for entry in session.exec(select(UTModel2)) + if entry.id % 2 == 0 + )) + session.bulk_delete(to_remove) + + expected_data = { + key: value + for key, value in test_data.items() + if key % 2 != 0 + } + + with sqlite3.connect(database.db_path) as conn: + cursor = conn.cursor() + + # Ensure deletes worked properly + query = """ + SELECT id + FROM test_table_2 + ORDER BY id + """ + cursor.execute(query) + values = cursor.fetchall() + assert len(values) == len(test_data) // 2 + for e_id, (a_id, *_) in zip(expected_data.keys(), values): + assert e_id == a_id + + # Ensure cascading worked properly to remove values from associated table + query = """ + SELECT id, value + FROM test_table_3 + ORDER BY id + """ + cursor.execute(query) + values = cursor.fetchall() + for (a_id, a_value), (e_id, e_value, *_) in zip(expected_data.items(), values): + assert a_id == e_id + assert a_value == e_value + + +def test_path_loader(database: UTDatabase): + with database.session() as session: + # Add values to db + with session.begin(): + session.add(UTModel4( + id=1, + posix_path=PurePosixPath("/test/path/posix"), + windows_path=PureWindowsPath(r"C:\test\path\windows"), + )) + + # Check that the values are stored in the intended manner in the backend + with sqlite3.connect(database.db_path) as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM test_table_4 WHERE id = 1") + key, posix_path, windows_path, *_ = cursor.fetchone() + assert posix_path == r"/test/path/posix" + assert windows_path == r"C:/test/path/windows" + + # Check loading the value into the correct type + data = session.exec(select(UTModel4).where(UTModel4.id == 1)).first() + assert isinstance(data.posix_path, PurePosixPath) + assert data.posix_path == PurePosixPath("/test/path/posix") + assert isinstance(data.windows_path, PureWindowsPath) + assert data.windows_path == 
PureWindowsPath(r"C:\test\path\windows") diff --git a/python/tests/test_futures.py b/python/tests/test_futures.py new file mode 100644 index 0000000..68e07cb --- /dev/null +++ b/python/tests/test_futures.py @@ -0,0 +1,24 @@ +import pytest + +import functools + +from dapper_python.dataset_generation.utils.futures import BoundedThreadPoolExecutor + + +def test_bounded_executor(): + # Don't really have a good way to examine the backend as to what has been submitted when, + # but can at least make sure we get the intended results back out + def example_func(num: int) -> int: + return num * 2 + + with BoundedThreadPoolExecutor() as pool: + threads = (( + functools.partial(example_func, x) + for x in range(1000) + )) + results = pool.bounded_run(threads) + + results = sorted([x.result() for x in results]) + + expected = list(range(0, 2000, 2)) + assert results == expected diff --git a/python/tests/test_itertools.py b/python/tests/test_itertools.py new file mode 100644 index 0000000..d46dd81 --- /dev/null +++ b/python/tests/test_itertools.py @@ -0,0 +1,26 @@ +import pytest + +from dapper_python.utils.yet_more_itertools import chunked_iter, enforce_single_type + + +def test_chunked_iter(): + data = list(range(100)) + data_iter = (x for x in data) + + for i, chunk in enumerate(chunked_iter(data_iter, 10)): + expected = list(range(i * 10, i * 10 + 10)) + assert list(chunk) == list(expected) + + data_iter = (x for x in data) + for chunk in chunked_iter(data_iter, 23): + list(chunk) + +def test_enforce_single_type(): + data = [1, 2, 3, 4, 5] + data_iter = (x for x in data) + assert list(enforce_single_type(data_iter)) == data + + data = [1, 2, 3, 4.0, "5"] + data_iter = (x for x in data) + with pytest.raises(TypeError): + assert list(enforce_single_type(data_iter)) == data \ No newline at end of file