diff --git a/immermatch/app.py b/immermatch/app.py index c4b54b7..f8a8acf 100644 --- a/immermatch/app.py +++ b/immermatch/app.py @@ -52,6 +52,11 @@ profile_candidate, search_all_queries, ) +from immermatch.search_provider import ( # noqa: E402 + get_provider, + get_provider_fingerprint, + parse_provider_query, # noqa: E402 +) # --------------------------------------------------------------------------- # Page configuration @@ -726,21 +731,23 @@ def _run_pipeline() -> None: return cache = _get_cache() + provider = get_provider(location) + provider_fingerprint = get_provider_fingerprint(provider) # Eagerly create the Gemini client so it's ready before the pipeline starts — # avoids lazy-init delay between query generation and job evaluation. client = create_client() if _keys_ok() else None # ---- Step 1: Generate queries ---------------------------------------- with st.status("✨ Crafting search queries...", expanded=False) as status: - cached_queries = cache.load_queries(profile, location) + cached_queries = cache.load_queries(profile, location, provider_fingerprint) if cached_queries is not None: queries = cached_queries status.update(label="✅ Queries generated (cached)", state="complete") else: if client is None: client = create_client() - queries = generate_search_queries(client, profile, location) - cache.save_queries(profile, location, queries) + queries = generate_search_queries(client, profile, location, provider=provider) + cache.save_queries(profile, location, queries, provider_fingerprint) status.update(label="✅ Queries generated", state="complete") st.session_state.queries = queries @@ -775,7 +782,7 @@ def _run_pipeline() -> None: job = futures[future] evaluation = future.result() ej = EvaluatedJob(job=job, evaluation=evaluation) - key = f"{ej.job.title}|{ej.job.company_name}" + key = f"{ej.job.title}|{ej.job.company_name}|{ej.job.location}" all_evals[key] = ej progress_bar.progress( i / len(new_jobs), @@ -802,7 +809,7 @@ def _run_pipeline() -> None: def _on_jobs_found(new_unique_jobs: list[JobListing]) -> None: """Submit newly found jobs for evaluation immediately.""" for job in new_unique_jobs: - key = f"{job.title}|{job.company_name}" + key = f"{job.title}|{job.company_name}|{job.location}" if key in all_evals: continue # already evaluated (from cache) fut = eval_executor.submit(evaluate_job, client, profile, job) @@ -825,6 +832,7 @@ def _search_progress(qi: int, total: int, unique: int) -> None: location=location, on_progress=_search_progress, on_jobs_found=_on_jobs_found, + provider=provider, ) cache.save_jobs(jobs, location) search_status.update(label=f"✅ Found {len(jobs)} unique jobs", state="complete") @@ -847,7 +855,7 @@ def _search_progress(qi: int, total: int, unique: int) -> None: job = eval_futures[future] evaluation = future.result() ej = EvaluatedJob(job=job, evaluation=evaluation) - key = f"{ej.job.title}|{ej.job.company_name}" + key = f"{ej.job.title}|{ej.job.company_name}|{ej.job.location}" all_evals[key] = ej eval_progress.progress( i / total_evals, @@ -951,7 +959,8 @@ def _record_ip_rate_limit() -> None: expanded=False, ): for q in st.session_state.queries: - st.markdown(f"- {q}") + _, clean_query = parse_provider_query(q) + st.markdown(f"- {clean_query}") # -- Profile (collapsed) ----------------------------------------------- if st.session_state.profile is not None: diff --git a/immermatch/bundesagentur.py b/immermatch/bundesagentur.py index 8c9f7f6..57774e7 100644 --- a/immermatch/bundesagentur.py +++ b/immermatch/bundesagentur.py @@ -234,6 +234,7 @@ class BundesagenturProvider: """ name: str = "Bundesagentur für Arbeit" + source_id: str = "bundesagentur" def __init__( self, diff --git a/immermatch/cache.py b/immermatch/cache.py index f94c0a4..798ee3b 100644 --- a/immermatch/cache.py +++ b/immermatch/cache.py @@ -77,7 +77,12 @@ def save_profile(self, cv_text: str, profile: CandidateProfile) -> None: # 2. Queries (keyed by profile hash + location) # ------------------------------------------------------------------ - def load_queries(self, profile: CandidateProfile, location: str) -> list[str] | None: + def load_queries( + self, + profile: CandidateProfile, + location: str, + provider_fingerprint: str = "", + ) -> list[str] | None: data = self._load("queries.json") if data is None: return None @@ -85,17 +90,26 @@ def load_queries(self, profile: CandidateProfile, location: str) -> list[str] | return None if data.get("location") != location: return None + if data.get("provider_fingerprint", "") != provider_fingerprint: + return None queries = data.get("queries") if not isinstance(queries, list): return None return queries - def save_queries(self, profile: CandidateProfile, location: str, queries: list[str]) -> None: + def save_queries( + self, + profile: CandidateProfile, + location: str, + queries: list[str], + provider_fingerprint: str = "", + ) -> None: self._save( "queries.json", { "profile_hash": _profile_hash(profile), "location": location, + "provider_fingerprint": provider_fingerprint, "queries": queries, }, ) @@ -130,7 +144,7 @@ def save_jobs(self, jobs: list[JobListing], location: str = "") -> None: existing = data.get("jobs", {}) for job in jobs: - key = f"{job.title}|{job.company_name}" + key = f"{job.title}|{job.company_name}|{job.location}" existing[key] = job.model_dump() self._save( @@ -143,7 +157,7 @@ def save_jobs(self, jobs: list[JobListing], location: str = "") -> None: ) # ------------------------------------------------------------------ - # 4. Evaluations (append-only, keyed by title|company) + # 4. Evaluations (append-only, keyed by title|company|location) # ------------------------------------------------------------------ def load_evaluations(self, profile: CandidateProfile) -> dict[str, EvaluatedJob]: @@ -188,5 +202,5 @@ def get_unevaluated_jobs( Jobs already in the evaluation cache are skipped. """ cached = self.load_evaluations(profile) - new_jobs = [job for job in jobs if f"{job.title}|{job.company_name}" not in cached] + new_jobs = [job for job in jobs if f"{job.title}|{job.company_name}|{job.location}" not in cached] return new_jobs, cached diff --git a/immermatch/search_agent.py b/immermatch/search_agent.py index e233d9c..8471d68 100644 --- a/immermatch/search_agent.py +++ b/immermatch/search_agent.py @@ -7,6 +7,7 @@ from __future__ import annotations +import logging import threading from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor, as_completed @@ -16,7 +17,13 @@ from .llm import call_gemini, parse_json from .models import CandidateProfile, JobListing -from .search_provider import SearchProvider, get_provider +from .search_provider import ( + CombinedSearchProvider, + SearchProvider, + format_provider_query, + get_provider, + parse_provider_query, +) # Re-export SerpApi helpers so existing imports keep working. from .serpapi_provider import BLOCKED_PORTALS as _BLOCKED_PORTALS # noqa: F401 @@ -29,6 +36,25 @@ from .serpapi_provider import parse_job_results as _parse_job_results # noqa: F401 from .serpapi_provider import search_jobs # noqa: F401 +logger = logging.getLogger(__name__) +_MIN_JOBS_PER_PROVIDER = 30 + + +def _provider_quota_source_key(provider: SearchProvider) -> str: + """Return a stable source key for per-provider quota accounting.""" + source_id = getattr(provider, "source_id", None) + if isinstance(source_id, str) and source_id.strip(): + return source_id.strip().lower() + name = getattr(provider, "name", None) + if isinstance(name, str) and name == "Bundesagentur für Arbeit": + return "bundesagentur" + if isinstance(name, str) and "serpapi" in name.lower(): + return "serpapi" + if type(provider).__name__ == "SerpApiProvider": + return "serpapi" + return type(provider).__name__.lower() + + # System prompt for the Profiler agent PROFILER_SYSTEM_PROMPT = """You are an expert technical recruiter with deep knowledge of European job markets. You will be given the raw text of a candidate's CV. Extract a comprehensive profile. @@ -197,6 +223,49 @@ def generate_search_queries( if provider is None: provider = get_provider(location) + if isinstance(provider, CombinedSearchProvider): + provider_count = len(provider.providers) + if provider_count == 0: + return [] + + per_provider = num_queries // provider_count + remainder = num_queries % provider_count + merged_queries: list[str] = [] + + for index, child_provider in enumerate(provider.providers): + child_count = per_provider + (1 if index < remainder else 0) + if child_count <= 0: + continue + child_queries = _generate_search_queries_for_provider( + client, + profile, + location, + child_count, + child_provider, + ) + merged_queries.extend([format_provider_query(child_provider.name, query) for query in child_queries]) + + seen: set[str] = set() + unique_queries: list[str] = [] + for query in merged_queries: + if query in seen: + continue + seen.add(query) + unique_queries.append(query) + if len(unique_queries) >= num_queries: + break + return unique_queries + + return _generate_search_queries_for_provider(client, profile, location, num_queries, provider) + + +def _generate_search_queries_for_provider( + client: genai.Client, + profile: CandidateProfile, + location: str, + num_queries: int, + provider: SearchProvider, +) -> list[str]: # Select system prompt based on active provider if provider.name == "Bundesagentur für Arbeit": system_prompt = BA_HEADHUNTER_SYSTEM_PROMPT @@ -268,7 +337,14 @@ def search_all_queries( if provider is None: provider = get_provider(location) - all_jobs: dict[str, JobListing] = {} # Use title+company as key for dedup + quota_sources: set[str] = set() + if isinstance(provider, CombinedSearchProvider): + quota_sources = {_provider_quota_source_key(p) for p in provider.providers} + if quota_sources and min_unique_jobs > 0: + min_unique_jobs = max(min_unique_jobs, _MIN_JOBS_PER_PROVIDER * len(quota_sources)) + + all_jobs: dict[str, JobListing] = {} # Use title+company+location as key for dedup + source_counts: dict[str, int] = {} lock = threading.Lock() completed = 0 early_stop = threading.Event() @@ -276,22 +352,34 @@ def search_all_queries( def _search_one(query: str) -> list[JobListing]: if early_stop.is_set(): return [] - return provider.search(query, location, max_results=jobs_per_query) + clean_query = query + if not isinstance(provider, CombinedSearchProvider): + _, clean_query = parse_provider_query(query) + return provider.search(clean_query, location, max_results=jobs_per_query) with ThreadPoolExecutor(max_workers=min(5, max(1, len(queries)))) as executor: futures = [executor.submit(_search_one, q) for q in queries] for future in as_completed(futures): - jobs = future.result() + jobs: list[JobListing] = [] + try: + jobs = future.result() + except Exception: + logger.exception("A search query failed") batch_new: list[JobListing] = [] with lock: for job in jobs: - key = f"{job.title}|{job.company_name}" + key = f"{job.title}|{job.company_name}|{job.location}" if key not in all_jobs: all_jobs[key] = job batch_new.append(job) + source = (job.source or "unknown").lower() + source_counts[source] = source_counts.get(source, 0) + 1 completed += 1 progress_args = (completed, len(queries), len(all_jobs)) - if min_unique_jobs and len(all_jobs) >= min_unique_jobs: + quota_met = True + if quota_sources: + quota_met = all(source_counts.get(source, 0) >= _MIN_JOBS_PER_PROVIDER for source in quota_sources) + if min_unique_jobs and len(all_jobs) >= min_unique_jobs and quota_met: early_stop.set() # Callbacks outside the lock to avoid blocking other threads if on_progress is not None: @@ -305,4 +393,19 @@ def _search_one(query: str) -> list[JobListing]: f.cancel() break + if source_counts: + counts_text = ", ".join(f"{source}={count}" for source, count in sorted(source_counts.items())) + logger.info("Search source counts for location '%s': %s", location or "(none)", counts_text) + if quota_sources: + missing = [ + source for source in sorted(quota_sources) if source_counts.get(source, 0) < _MIN_JOBS_PER_PROVIDER + ] + if missing: + logger.warning( + "Provider quota not reached for location '%s': %s (required >= %d each)", + location or "(none)", + ", ".join(missing), + _MIN_JOBS_PER_PROVIDER, + ) + return list(all_jobs.values()) diff --git a/immermatch/search_provider.py b/immermatch/search_provider.py index 4fa6745..0c81c8e 100644 --- a/immermatch/search_provider.py +++ b/immermatch/search_provider.py @@ -8,12 +8,59 @@ from __future__ import annotations import logging +import math +import os from typing import Protocol, runtime_checkable from .models import JobListing logger = logging.getLogger(__name__) +_PROVIDER_QUERY_PREFIX = "provider=" +_PROVIDER_QUERY_SEPARATOR = "::" + + +def format_provider_query(provider_name: str, query: str) -> str: + """Format a query with explicit provider routing metadata.""" + return f"{_PROVIDER_QUERY_PREFIX}{provider_name}{_PROVIDER_QUERY_SEPARATOR}{query}" + + +def parse_provider_query(query: str) -> tuple[str | None, str]: + """Parse an optionally provider-targeted query. + + Query format: + provider=:: + + Returns: + (target_provider_name, clean_query) + """ + if query.startswith(_PROVIDER_QUERY_PREFIX) and _PROVIDER_QUERY_SEPARATOR in query: + meta, clean_query = query.split(_PROVIDER_QUERY_SEPARATOR, 1) + target_provider = meta.removeprefix(_PROVIDER_QUERY_PREFIX).strip() + if target_provider and clean_query.strip(): + return target_provider, clean_query.strip() + return None, query + + +def get_provider_fingerprint(provider: SearchProvider) -> str: + """Return a stable fingerprint for the active provider configuration. + + Used by query cache to avoid reusing provider-targeted query sets when + provider configuration changes (e.g. SerpApi enabled/disabled). + """ + + def _provider_key(p: SearchProvider) -> str: + source_id = getattr(p, "source_id", None) + if isinstance(source_id, str) and source_id.strip(): + return source_id.strip().lower() + name = getattr(p, "name", "") + if isinstance(name, str) and name.strip(): + return name.strip().lower() + return type(p).__name__.lower() + + providers = provider.providers if isinstance(provider, CombinedSearchProvider) else [provider] + return "|".join(sorted({_provider_key(p) for p in providers})) + @runtime_checkable class SearchProvider(Protocol): @@ -45,16 +92,72 @@ def search( ... +class CombinedSearchProvider: + """Run multiple providers for each query and merge their results.""" + + name: str = "Bundesagentur + SerpApi" + + def __init__(self, providers: list[SearchProvider]) -> None: + self.providers = providers + + def search( + self, + query: str, + location: str, + max_results: int = 50, + ) -> list[JobListing]: + if not self.providers: + return [] + + target_provider, clean_query = parse_provider_query(query) + providers = self.providers + if target_provider is not None: + providers = [provider for provider in self.providers if provider.name == target_provider] + if not providers: + logger.warning( + "Unknown targeted provider '%s' in query, falling back to all providers", target_provider + ) + providers = self.providers + + if not providers: + return [] + + if max_results <= 0: + return [] + + merged: dict[str, JobListing] = {} + per_provider = max(1, math.ceil(max_results / len(providers))) + for provider in providers: + try: + jobs = provider.search(clean_query, location, max_results=per_provider) + except Exception: + logger.exception("Provider '%s' failed for query '%s'", provider.name, clean_query) + continue + + for job in jobs: + key = f"{job.title}|{job.company_name}|{job.location}" + if key not in merged: + merged[key] = job + + return list(merged.values())[:max_results] + + def get_provider(location: str = "") -> SearchProvider: # noqa: ARG001 """Return the appropriate ``SearchProvider`` for *location*. - Currently always returns the Bundesagentur für Arbeit provider - (Germany-only). This factory is the single extension point for - future per-country routing — e.g. returning ``SerpApiProvider`` - for non-German locations. + Returns a combined provider that merges Bundesagentur and SerpApi + results when ``SERPAPI_KEY`` is available. If SerpApi is not + configured, falls back to Bundesagentur only. """ # Lazy import so the module can be loaded without pulling in httpx # when only the protocol is needed (e.g. for type-checking). from .bundesagentur import BundesagenturProvider # noqa: PLC0415 + from .serpapi_provider import SerpApiProvider # noqa: PLC0415 + + providers: list[SearchProvider] = [BundesagenturProvider()] + if os.getenv("SERPAPI_KEY"): + providers.append(SerpApiProvider()) - return BundesagenturProvider() + if len(providers) == 1: + return providers[0] + return CombinedSearchProvider(providers) diff --git a/immermatch/serpapi_provider.py b/immermatch/serpapi_provider.py index cf9bdbf..7a21a0a 100644 --- a/immermatch/serpapi_provider.py +++ b/immermatch/serpapi_provider.py @@ -348,6 +348,7 @@ class SerpApiProvider: """ name: str = "SerpApi (Google Jobs)" + source_id: str = "serpapi" def search( self, diff --git a/tests/test_cache.py b/tests/test_cache.py index 87eccde..06a4223 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -70,6 +70,10 @@ def test_miss_on_different_profile(self, cache: ResultCache, profile: CandidateP ) assert cache.load_queries(other, "Munich") is None + def test_miss_on_different_provider_fingerprint(self, cache: ResultCache, profile: CandidateProfile): + cache.save_queries(profile, "Munich", ["q1"], provider_fingerprint="bundesagentur|serpapi") + assert cache.load_queries(profile, "Munich", provider_fingerprint="bundesagentur") is None + class TestJobsCache: @freeze_time("2026-02-20") @@ -132,17 +136,17 @@ class TestEvaluationsCache: def test_round_trip(self, cache: ResultCache, profile: CandidateProfile): job = JobListing(title="Dev", company_name="Corp", location="Berlin") ev = JobEvaluation(score=80, reasoning="Good match.") - evaluated = {"Dev|Corp": EvaluatedJob(job=job, evaluation=ev)} + evaluated = {"Dev|Corp|Berlin": EvaluatedJob(job=job, evaluation=ev)} cache.save_evaluations(profile, evaluated) loaded = cache.load_evaluations(profile) - assert "Dev|Corp" in loaded - assert loaded["Dev|Corp"].evaluation.score == 80 + assert "Dev|Corp|Berlin" in loaded + assert loaded["Dev|Corp|Berlin"].evaluation.score == 80 def test_miss_on_different_profile(self, cache: ResultCache, profile: CandidateProfile): job = JobListing(title="Dev", company_name="Corp", location="Berlin") ev = JobEvaluation(score=80, reasoning="Good match.") - cache.save_evaluations(profile, {"Dev|Corp": EvaluatedJob(job=job, evaluation=ev)}) + cache.save_evaluations(profile, {"Dev|Corp|Berlin": EvaluatedJob(job=job, evaluation=ev)}) other = CandidateProfile( skills=["Java"], @@ -159,9 +163,9 @@ def test_filters_already_evaluated(self, cache: ResultCache, profile: CandidateP job1 = JobListing(title="Dev", company_name="Corp", location="Berlin") job2 = JobListing(title="PM", company_name="Corp", location="Berlin") ev = JobEvaluation(score=80, reasoning="Good.") - cache.save_evaluations(profile, {"Dev|Corp": EvaluatedJob(job=job1, evaluation=ev)}) + cache.save_evaluations(profile, {"Dev|Corp|Berlin": EvaluatedJob(job=job1, evaluation=ev)}) new_jobs, cached = cache.get_unevaluated_jobs([job1, job2], profile) assert len(new_jobs) == 1 assert new_jobs[0].title == "PM" - assert "Dev|Corp" in cached + assert "Dev|Corp|Berlin" in cached diff --git a/tests/test_integration.py b/tests/test_integration.py index 8108e1b..bf6e0a6 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -322,6 +322,12 @@ def mock_client() -> MagicMock: return MagicMock() +def _query_provider() -> MagicMock: + provider = MagicMock() + provider.name = "Bundesagentur für Arbeit" + return provider + + # --------------------------------------------------------------------------- # Integration tests # --------------------------------------------------------------------------- @@ -368,7 +374,7 @@ def test_full_pipeline_happy_path( assert len(profile.work_history) == 2 # --- Act: Stage 2 — Queries --- - queries = generate_search_queries(mock_client, profile, "Munich, Germany") + queries = generate_search_queries(mock_client, profile, "Munich, Germany", provider=_query_provider()) assert isinstance(queries, list) assert len(queries) == 20 @@ -452,7 +458,7 @@ def test_full_pipeline_non_tech_cv( assert profile.experience_level == "Mid" assert any("Sustainability" in r for r in profile.roles) - queries = generate_search_queries(mock_client, profile, "Munich, Germany") + queries = generate_search_queries(mock_client, profile, "Munich, Germany", provider=_query_provider()) assert len(queries) == 20 jobs = search_all_queries( @@ -534,7 +540,7 @@ def test_queries_are_strings_and_correct_count( mock_gemini.side_effect = [TECH_PROFILE_JSON, QUERIES_JSON] profile = profile_candidate(mock_client, tech_cv_text) - queries = generate_search_queries(mock_client, profile, "Munich, Germany") + queries = generate_search_queries(mock_client, profile, "Munich, Germany", provider=_query_provider()) assert len(queries) == 20 assert all(isinstance(q, str) for q in queries) @@ -689,7 +695,7 @@ def test_empty_search_produces_empty_evaluations( mock_provider.search.return_value = [] profile = profile_candidate(mock_client, tech_cv_text) - queries = generate_search_queries(mock_client, profile, "Munich, Germany") + queries = generate_search_queries(mock_client, profile, "Munich, Germany", provider=_query_provider()) jobs = search_all_queries( queries, jobs_per_query=10, @@ -739,7 +745,7 @@ def test_cv_data_flows_through_all_stages( assert "TechCorp" in profile_prompt or "John Doe" in profile_prompt # Stage 2: Queries — verify profile data was sent to Gemini - queries = generate_search_queries(mock_client, profile, "Munich, Germany") + queries = generate_search_queries(mock_client, profile, "Munich, Germany", provider=_query_provider()) query_prompt = mock_search_gemini.call_args_list[1][0][1] assert "Senior Software Engineer" in query_prompt # from profile.roles assert "Python" in query_prompt # from profile.skills diff --git a/tests/test_search_agent.py b/tests/test_search_agent.py index ba047af..085c7e4 100644 --- a/tests/test_search_agent.py +++ b/tests/test_search_agent.py @@ -1,6 +1,7 @@ """Tests for immermatch.search_agent — pure helper functions and search_all_queries orchestration.""" import json +from typing import ClassVar from unittest.mock import MagicMock, patch import pytest @@ -11,10 +12,12 @@ _is_remote_only, _localise_query, _parse_job_results, + _provider_quota_source_key, generate_search_queries, profile_candidate, search_all_queries, ) +from immermatch.search_provider import CombinedSearchProvider class TestIsRemoteOnly: @@ -178,11 +181,11 @@ def test_highlights_in_description(self): class TestSearchAllQueries: """Tests for search_all_queries() — mock provider to test orchestration logic.""" - def _make_job(self, title: str, company: str = "Co") -> JobListing: + def _make_job(self, title: str, company: str = "Co", location: str = "Berlin") -> JobListing: return JobListing( title=title, company_name=company, - location="Berlin", + location=location, apply_options=[ApplyOption(source="LinkedIn", url="https://linkedin.com/1")], ) @@ -208,8 +211,14 @@ def test_passes_query_and_location_to_provider(self): max_results=10, ) - def test_deduplicates_by_title_and_company(self): - provider = self._make_provider([self._make_job("Dev"), self._make_job("Dev")]) + def test_deduplicates_by_title_company_and_location(self): + provider = self._make_provider( + [ + self._make_job("Dev", location="Berlin"), + self._make_job("Dev", location="Berlin"), + self._make_job("Dev", location="Munich"), + ] + ) results = search_all_queries( queries=["query1", "query2"], @@ -218,7 +227,7 @@ def test_deduplicates_by_title_and_company(self): provider=provider, ) - assert len(results) == 1 + assert len(results) == 2 def test_stops_early_when_min_unique_jobs_reached(self): provider = self._make_provider([self._make_job("Unique Job")]) @@ -276,6 +285,130 @@ def test_defaults_to_get_provider(self, mock_gp: MagicMock): mock_gp.assert_called_once_with("Berlin") + def test_combined_provider_hard_quota_requires_30_each_before_stop(self): + ba_provider = MagicMock() + ba_provider.name = "Bundesagentur für Arbeit" + ba_provider.source_id = "bundesagentur" + ba_jobs = [self._make_job(f"BA {i}", company=f"BA Co {i}", location="Berlin") for i in range(30)] + for job in ba_jobs: + job.source = "bundesagentur" + ba_provider.search.return_value = ba_jobs + + serp_provider = MagicMock() + serp_provider.name = "SerpApi (Google Jobs)" + serp_provider.source_id = "serpapi" + serp_jobs = [self._make_job(f"SERP {i}", company=f"SERP Co {i}", location="Berlin") for i in range(30)] + for job in serp_jobs: + job.source = "serpapi" + serp_provider.search.return_value = serp_jobs + + combined = CombinedSearchProvider([ba_provider, serp_provider]) + results = search_all_queries( + queries=[ + "provider=Bundesagentur für Arbeit::Softwareentwickler", + "provider=SerpApi (Google Jobs)::Python Developer Berlin", + ], + jobs_per_query=30, + location="Berlin", + min_unique_jobs=50, + provider=combined, + ) + + assert len(results) == 60 + ba_count = len([job for job in results if job.source == "bundesagentur"]) + serp_count = len([job for job in results if job.source == "serpapi"]) + assert ba_count >= 30 + assert serp_count >= 30 + + @patch("immermatch.search_agent.logger") + def test_logs_source_counts(self, mock_logger: MagicMock): + provider = self._make_provider( + [ + self._make_job("BA Job", location="Berlin"), + self._make_job("SERP Job", location="Munich"), + ] + ) + provider.search.return_value[0].source = "bundesagentur" + provider.search.return_value[1].source = "serpapi" + + search_all_queries( + queries=["query1"], + location="Berlin", + min_unique_jobs=0, + provider=provider, + ) + + assert mock_logger.info.called + logged_texts = " ".join(str(call.args) for call in mock_logger.info.call_args_list) + assert "bundesagentur" in logged_texts + assert "serpapi" in logged_texts + + def test_combined_provider_routes_query_to_target_provider(self): + ba_provider = MagicMock() + ba_provider.name = "Bundesagentur für Arbeit" + ba_provider.search.return_value = [] + + serp_provider = MagicMock() + serp_provider.name = "SerpApi (Google Jobs)" + serp_provider.search.return_value = [self._make_job("Dev", location="Berlin")] + + combined = CombinedSearchProvider([ba_provider, serp_provider]) + search_all_queries( + queries=["provider=SerpApi (Google Jobs)::Python Developer Berlin"], + location="Berlin", + min_unique_jobs=0, + provider=combined, + ) + + ba_provider.search.assert_not_called() + serp_provider.search.assert_called_once_with("Python Developer Berlin", "Berlin", max_results=10) + + def test_provider_quota_source_key_prefers_source_id(self): + class ThirdProvider: + name: ClassVar[str] = "Third Provider" + source_id: ClassVar[str] = "third-source" + + ba_provider = MagicMock() + ba_provider.name = "Bundesagentur für Arbeit" + + serp_provider = MagicMock() + serp_provider.name = "SerpApi (Google Jobs)" + serp_provider.source_id = "serpapi" + + third_provider = ThirdProvider() + + assert _provider_quota_source_key(ba_provider) == "bundesagentur" + assert _provider_quota_source_key(serp_provider) == "serpapi" + assert _provider_quota_source_key(third_provider) == "third-source" + + def test_min_unique_zero_does_not_enable_combined_quota(self): + ba_provider = MagicMock() + ba_provider.name = "Bundesagentur für Arbeit" + ba_provider.source_id = "bundesagentur" + ba_jobs = [self._make_job(f"BA {i}", company=f"BA Co {i}", location="Berlin") for i in range(10)] + for job in ba_jobs: + job.source = "bundesagentur" + ba_provider.search.return_value = ba_jobs + + serp_provider = MagicMock() + serp_provider.name = "SerpApi (Google Jobs)" + serp_provider.source_id = "serpapi" + serp_jobs = [self._make_job(f"SERP {i}", company=f"SERP Co {i}", location="Berlin") for i in range(10)] + for job in serp_jobs: + job.source = "serpapi" + serp_provider.search.return_value = serp_jobs + + combined = CombinedSearchProvider([ba_provider, serp_provider]) + results = search_all_queries( + queries=["q1", "q2"], + jobs_per_query=10, + location="Berlin", + min_unique_jobs=0, + provider=combined, + ) + + assert len(results) == 10 + class TestLlmJsonRecovery: @patch("immermatch.search_agent.call_gemini") @@ -325,8 +458,16 @@ def test_generate_search_queries_retries_after_invalid_json(self, mock_call_gemi education_history=[], ) mock_call_gemini.side_effect = ["not json", '["python developer berlin", "backend berlin"]'] + provider = MagicMock() + provider.name = "SerpApi (Google Jobs)" - queries = generate_search_queries(MagicMock(), profile, location="Berlin, Germany", num_queries=2) + queries = generate_search_queries( + MagicMock(), + profile, + location="Berlin, Germany", + num_queries=2, + provider=provider, + ) assert queries == ["python developer berlin", "backend berlin"] assert mock_call_gemini.call_count == 2 @@ -356,8 +497,16 @@ def test_generate_search_queries_returns_empty_list_after_all_retries_fail(self, education_history=[], ) mock_call_gemini.side_effect = ["not json", "still not json"] + provider = MagicMock() + provider.name = "SerpApi (Google Jobs)" - queries = generate_search_queries(MagicMock(), profile, location="Berlin, Germany", num_queries=2) + queries = generate_search_queries( + MagicMock(), + profile, + location="Berlin, Germany", + num_queries=2, + provider=provider, + ) assert queries == [] assert mock_call_gemini.call_count == 2 @@ -480,3 +629,30 @@ def test_other_provider_uses_default_prompt(self, mock_call_gemini: MagicMock): prompt_sent = mock_call_gemini.call_args[0][1] assert "Google Jobs" in prompt_sent assert "LOCAL names" in prompt_sent + + @patch("immermatch.search_agent.call_gemini") + def test_combined_provider_generates_queries_per_child_provider(self, mock_call_gemini: MagicMock): + mock_call_gemini.side_effect = [ + '["Softwareentwickler", "Datenanalyst"]', + '["Python Developer Berlin", "Data Engineer Berlin"]', + ] + + ba_provider = MagicMock() + ba_provider.name = "Bundesagentur für Arbeit" + serp_provider = MagicMock() + serp_provider.name = "SerpApi (Google Jobs)" + combined = CombinedSearchProvider([ba_provider, serp_provider]) + + queries = generate_search_queries( + MagicMock(), + self._PROFILE, + location="Berlin", + num_queries=4, + provider=combined, + ) + + assert len(queries) == 4 + assert all(query.startswith("provider=") for query in queries) + prompts_sent = [call.args[1] for call in mock_call_gemini.call_args_list] + assert any("Bundesagentur" in prompt for prompt in prompts_sent) + assert any("Google Jobs" in prompt for prompt in prompts_sent) diff --git a/tests/test_search_provider.py b/tests/test_search_provider.py new file mode 100644 index 0000000..f17e9f2 --- /dev/null +++ b/tests/test_search_provider.py @@ -0,0 +1,56 @@ +"""Tests for search provider helpers and combined provider behavior.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +from immermatch.models import ApplyOption, JobListing +from immermatch.search_provider import CombinedSearchProvider, parse_provider_query + + +def _make_job(title: str, company: str, location: str = "Berlin") -> JobListing: + return JobListing( + title=title, + company_name=company, + location=location, + apply_options=[ApplyOption(source="Company Website", url="https://example.com")], + ) + + +class TestParseProviderQuery: + def test_parses_targeted_query(self): + target, query = parse_provider_query("provider=SerpApi (Google Jobs)::Python Developer Berlin") + assert target == "SerpApi (Google Jobs)" + assert query == "Python Developer Berlin" + + def test_returns_original_when_not_targeted(self): + target, query = parse_provider_query("Softwareentwickler") + assert target is None + assert query == "Softwareentwickler" + + +class TestCombinedSearchProvider: + def test_splits_max_results_budget_across_providers(self): + p1 = MagicMock() + p1.name = "Bundesagentur für Arbeit" + p1.search.return_value = [_make_job(f"BA {i}", f"BA Co {i}") for i in range(3)] + + p2 = MagicMock() + p2.name = "SerpApi (Google Jobs)" + p2.search.return_value = [_make_job(f"SERP {i}", f"SERP Co {i}") for i in range(3)] + + provider = CombinedSearchProvider([p1, p2]) + results = provider.search("Developer", "Berlin", max_results=5) + + p1.search.assert_called_once_with("Developer", "Berlin", max_results=3) + p2.search.assert_called_once_with("Developer", "Berlin", max_results=3) + assert len(results) == 5 + + def test_returns_empty_when_max_results_non_positive(self): + p1 = MagicMock() + p1.name = "Bundesagentur für Arbeit" + p1.search.return_value = [_make_job("BA", "BA Co")] + + provider = CombinedSearchProvider([p1]) + assert provider.search("Developer", "Berlin", max_results=0) == [] + p1.search.assert_not_called()