diff --git a/openhtf/util/validators.py b/openhtf/util/validators.py index 7c1d83648..d37cd7bad 100644 --- a/openhtf/util/validators.py +++ b/openhtf/util/validators.py @@ -441,6 +441,41 @@ def matches_regex(regex): return RegexMatcher(regex, re.compile(regex)) +class MultiRegexMatcher(ValidatorBase): + def __init__(self, regex_list: list[str], compiled_list: list[re.Pattern]) -> None: + self.regex_list = regex_list + self._compiled_list = compiled_list + + def __call__(self, candidate_str: str) -> bool: + for compiled_pattern in self._compiled_list: + if compiled_pattern.match(candidate_str): + return True + return False + + def __deepcopy__(self, dummy_memo): + return type(self)(self.regex_list[:], self._compiled_list[:]) + + def __str__(self): + patterns_str = " | ".join(self.regex_list) + return "'x' matches any of: /%s/" % patterns_str + + def __eq__(self, other): + return isinstance(other, type(self)) and self.regex_list == other.regex_list + + def __ne__(self, other) -> bool: + return not self == other + +@register +def matches_any_regex(*regex_collections: list[str]): + flat_regex_list = [] + for collection in regex_collections: + if not isinstance(collection, list) or not collection: + raise ValueError("Each argument must be a list of regex patterns.") + flat_regex_list.extend(collection) + + compiled_list = [re.compile(regex) for regex in flat_regex_list] + return MultiRegexMatcher(flat_regex_list, compiled_list) + class WithinPercent(RangeValidatorBase): """Validates that a number is within percent of a value.""" diff --git a/test/util/validators_test.py b/test/util/validators_test.py index f3060e9ca..5c4e65dce 100644 --- a/test/util/validators_test.py +++ b/test/util/validators_test.py @@ -78,6 +78,42 @@ def test_with_custom_type(self): self.assertEqual(test_validator.maximum, 0x12) +class TestSingleRegex(unittest.TestCase): + def test_single_regex(self): + pattern = r'^[A-Z]{3}\d{3}$' + validator = validators.matches_regex(pattern) + self.assertTrue(validator('ABC123')) + self.assertFalse(validator('abc123')) + self.assertFalse(validator('AB1234')) + self.assertFalse(validator('ABCD12')) + + +class TestMultipleRegex(unittest.TestCase): + patterns_1 = [r'^[A-Z]{1}\d{1}$', r'^[A-Z]{2}\d{2}$', r'^[A-Z]{3}\d{3}$'] + patterns_2 = [r'^\d{1}[A-Z]{1}$'] + def test_multiple_regex_lists(self): + + validator_1 = validators.matches_any_regex(TestMultipleRegex.patterns_1, TestMultipleRegex.patterns_2) + self.assertTrue(validator_1('ABC123')) + self.assertTrue(validator_1('A1')) + self.assertTrue(validator_1('1A')) + self.assertFalse(validator_1('123-ABCD')) + + def test_single_regex_list(self): + validator_2 = validators.matches_any_regex(TestMultipleRegex.patterns_2) + self.assertTrue(validator_2('1A')) + self.assertFalse(validator_2('A1')) + + def test_invalid_arguments(self): + with self.assertRaisesRegex(ValueError, "Each argument must be a list of regex patterns."): + validators.matches_any_regex(r'^[A-Z]{3}\d{3}$') + + with self.assertRaisesRegex(ValueError, "Each argument must be a list of regex patterns."): + validators.matches_any_regex([]) + + with self.assertRaisesRegex(ValueError, "Each argument must be a list of regex patterns."): + validators.matches_any_regex(r'd{3}', r'd{4}') + class TestAllInRange(unittest.TestCase): def setUp(self):