{
apiBasePath={config.apiBasePath}
/>
- {notInIframe && (
+ {!inIFrame && (
}
- onClick={() => navigate('/logout')}
+ onClick={() => {
+ navigate(`/logout`);
+ }}
style={{ marginTop: 0 }}
>
Logout
@@ -151,4 +189,4 @@ export const MainPage = () => {
);
};
-export default MainPage;
+export default DemoPage;
diff --git a/client/src/pages/LandingPage.tsx b/client/src/pages/LandingPage.tsx
new file mode 100644
index 0000000..1c59ccb
--- /dev/null
+++ b/client/src/pages/LandingPage.tsx
@@ -0,0 +1,109 @@
+import { useNavigate } from 'react-router-dom';
+import { Button, Typography, Card, Row, Col, Layout, Spin } from 'antd';
+import { CodeOutlined, PictureOutlined } from '@ant-design/icons';
+import { useToken } from '../hooks/useToken';
+import { useConfig } from '../hooks/useConfig';
+
+const { Title, Paragraph } = Typography;
+const { Content } = Layout;
+
+const LandingPage = () => {
+ const navigate = useNavigate();
+ const config = useConfig();
+ const { data: tokenData, isLoading: tokenLoading } = useToken();
+
+ const handleLogin = () => {
+ sessionStorage.setItem('oauth_return_to', '/');
+
+ const state = Math.random().toString(36).substring(7);
+ sessionStorage.setItem('oauth_state', state);
+
+ const authUrl = new URL('https://designsafe.tapis.io/v3/oauth2/authorize');
+ authUrl.searchParams.append('client_id', config.clientId);
+ authUrl.searchParams.append(
+ 'redirect_uri',
+ `${window.location.origin}/imageinf/ui/auth/callback/`
+ );
+ authUrl.searchParams.append('response_type', 'token');
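+    // Implicit grant: the access token is returned in the redirect URL fragment and handled by the auth callback route.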
+
+ window.location.href = authUrl.toString();
+ };
+
+  if (tokenLoading) {
+    return (
+      <Layout style={{ minHeight: '100vh' }}>
+        <Content style={{ display: 'flex', alignItems: 'center', justifyContent: 'center' }}>
+          <Spin size="large" />
+        </Content>
+      </Layout>
+    );
+  }
+
+  return (
+    <Layout style={{ minHeight: '100vh' }}>
+      <Content style={{ padding: '48px 24px', textAlign: 'center' }}>
+        <Title>imageInf</Title>
+        <Paragraph>Image inferencing service.</Paragraph>
+
+        {tokenData?.isValid ? (
+          <Row gutter={[24, 24]} justify="center">
+            <Col xs={24} sm={12} md={8}>
+              <Card hoverable onClick={() => navigate('/demo')}>
+                <CodeOutlined style={{ fontSize: 48 }} />
+                <Title level={4}>Developer Demo</Title>
+                <Paragraph>Explore API and image inference.</Paragraph>
+              </Card>
+            </Col>
+            <Col xs={24} sm={12} md={8}>
+              <Card hoverable onClick={() => navigate('/classify')}>
+                <PictureOutlined style={{ fontSize: 48 }} />
+                <Title level={4}>Gallery Classifier</Title>
+                <Paragraph>Classify curated image sets.</Paragraph>
+              </Card>
+            </Col>
+          </Row>
+        ) : (
+          <Button type="primary" size="large" onClick={handleLogin}>
+            Log in
+          </Button>
+        )}
+      </Content>
+    </Layout>
+  );
+};
+
+export default LandingPage;
diff --git a/client/src/pages/Login.tsx b/client/src/pages/Login.tsx
index 5b4f6c9..6a92bc8 100644
--- a/client/src/pages/Login.tsx
+++ b/client/src/pages/Login.tsx
@@ -1,12 +1,21 @@
-import { Button } from 'antd';
+import { useEffect } from 'react';
+import { useSearchParams } from 'react-router-dom';
+import { Spin, Layout } from 'antd';
import { useConfig } from '../hooks/useConfig';
+const { Content } = Layout;
+
const Login = () => {
const config = useConfig();
+ const [searchParams] = useSearchParams();
+
+ useEffect(() => {
+ // Store where to return after successful auth
+ const returnTo = searchParams.get('returnTo') || '/';
+ sessionStorage.setItem('oauth_return_to', returnTo);
- const handleLogin = () => {
// Generate a random state parameter for security
- // a store state in sessionStorage for verification
+ // and store state in sessionStorage for verification
const state = Math.random().toString(36).substring(7);
sessionStorage.setItem('oauth_state', state);
@@ -19,31 +28,19 @@ const Login = () => {
);
authUrl.searchParams.append('response_type', 'token');
- // TODO tapis not supporting at the moment
-
+    // TODO: Tapis does not support the state parameter at the moment
//authUrl.searchParams.append('state', state);
// Redirect to OAuth provider
- window.location.href = authUrl.toString();
- };
+ window.location.replace(authUrl.toString());
+ }, [config.clientId, searchParams]);
return (
-    <div>
-      <h1>Image Inferencing Service Login</h1>
-      <Button type="primary" onClick={handleLogin}>
-        Log in
-      </Button>
-    </div>
+    <Layout style={{ minHeight: '100vh' }}>
+      <Content style={{ display: 'flex', alignItems: 'center', justifyContent: 'center' }}>
+        <Spin size="large" />
+      </Content>
+    </Layout>
);
};
diff --git a/client/src/pages/Logout.tsx b/client/src/pages/Logout.tsx
index ffdc6e9..46f48b9 100644
--- a/client/src/pages/Logout.tsx
+++ b/client/src/pages/Logout.tsx
@@ -6,11 +6,9 @@ const Logout = () => {
const navigate = useNavigate();
useEffect(() => {
- // Clear all session storage
sessionStorage.clear();
-
- // Redirect to login
- navigate('/login');
+ // go to landing page
+ navigate('/', { replace: true });
}, [navigate]);
return (
diff --git a/client/src/types/inference.ts b/client/src/types/inference.ts
index f7953b4..0a288fe 100644
--- a/client/src/types/inference.ts
+++ b/client/src/types/inference.ts
@@ -12,20 +12,34 @@ export interface InferenceResult {
systemId: string;
path: string;
predictions: Prediction[];
+ metadata?: ImageMetadata | null;
+}
+
+export interface ImageMetadata {
+ date_taken?: string;
+ latitude?: number | null;
+ longitude?: number | null;
+ altitude?: number | null;
+ camera_make?: string;
+ camera_model?: string;
}
export interface InferenceRequest {
files: TapisFile[];
model?: string;
+ labels?: string[]; // Note: CLIP only
+ sensitivity?: 'high' | 'medium' | 'low'; // Note: CLIP only
}
export interface InferenceResponse {
model: string;
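+  // Category-level rollup of the per-image results; for CLIP models this mirrors results unchanged.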
+ aggregated_results: InferenceResult[];
results: InferenceResult[];
}
export interface InferenceModelMeta {
name: string;
+ type: string;
description: string;
link: string;
}
diff --git a/imageinf/inference/clip_base.py b/imageinf/inference/clip_base.py
index 95aaeea..39ca657 100644
--- a/imageinf/inference/clip_base.py
+++ b/imageinf/inference/clip_base.py
@@ -36,6 +36,12 @@ class BaseCLIPModel:
"sky",
]
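+    # Each preset pairs a confidence threshold with a softmax temperature;
+    # "high" sensitivity keeps more labels (lower threshold), "low" keeps fewer.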
+ SENSITIVITY_PRESETS = {
+ "high": {"threshold": 0.45, "temperature": 15.0},
+ "medium": {"threshold": 0.55, "temperature": 20.0},
+ "low": {"threshold": 0.65, "temperature": 25.0},
+ }
+
DEFAULT_THRESHOLD = 0.55
BINARY_TEMPERATURE = 20.0
@@ -75,12 +81,16 @@ def _precompute_text_features(self):
def classify_image(
self,
image: Image.Image,
- threshold: Optional[float] = None,
- top_k: Optional[int] = None,
+ sensitivity: str = "medium",
debug_when_empty: bool = True,
) -> List[Prediction]:
- if threshold is None:
- threshold = self.DEFAULT_THRESHOLD
+
+ # Get threshold and temperature from sensitivity preset
+ preset = self.SENSITIVITY_PRESETS.get(
+ sensitivity, self.SENSITIVITY_PRESETS["medium"]
+ )
+ threshold = preset["threshold"]
+ temperature = preset["temperature"]
if image.mode != "RGB":
image = image.convert("RGB")
@@ -95,7 +105,7 @@ def classify_image(
img_feat = F.normalize(img_feat, dim=-1)
sims2 = torch.einsum("bd,lcd->blc", img_feat, self.text_pairs)
- logits2 = sims2 * self.BINARY_TEMPERATURE
+ logits2 = sims2 * temperature
probs2 = torch.softmax(logits2, dim=-1)[0]
presence = probs2[:, 0]
@@ -107,8 +117,6 @@ def classify_image(
preds_all.sort(key=lambda p: p.score, reverse=True)
preds = [p for p in preds_all if p.score >= threshold]
- if top_k is not None:
- preds = preds[:top_k]
if not preds and debug_when_empty:
top_dbg = preds_all[:5]
diff --git a/imageinf/inference/clip_models.py b/imageinf/inference/clip_models.py
index b24cf31..2fae273 100644
--- a/imageinf/inference/clip_models.py
+++ b/imageinf/inference/clip_models.py
@@ -4,6 +4,7 @@
@register_model_runner(
"openai/clip-vit-large-patch14",
+ model_type="clip",
description="CLIP ViT-Large - zero-shot multi-label (~400M params)",
link="https://huggingface.co/openai/clip-vit-large-patch14",
)
@@ -15,6 +16,7 @@ class CLIPViTLarge(BaseCLIPModel):
@register_model_runner(
"wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M",
+ model_type="clip",
description="TinyCLIP - efficient zero-shot classifier (~59M params total)",
link="https://huggingface.co/wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M",
)
@@ -27,6 +29,7 @@ class TinyCLIP(BaseCLIPModel):
@register_model_runner(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
+ model_type="clip",
description="CLIP ViT-Huge - highest accuracy zero-shot (~1B params)",
link="https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
)
diff --git a/imageinf/inference/models.py b/imageinf/inference/models.py
index bd50b5f..96125ba 100644
--- a/imageinf/inference/models.py
+++ b/imageinf/inference/models.py
@@ -1,5 +1,5 @@
from pydantic import BaseModel
-from typing import List, Optional
+from typing import List, Optional, Literal
from datetime import datetime
@@ -35,10 +35,13 @@ class InferenceResponse(BaseModel):
model: str
aggregated_results: List[InferenceResult]
results: List[InferenceResult]
- metadata: Optional[ImageMetadata] = None
class InferenceRequest(BaseModel):
inferenceType: str = "classification"
files: List[TapisFile]
- model: str = "google/vit-base-patch16-224"
+ model: str = ("google/vit-base-patch16-224",)
+ labels: Optional[List[str]] = None # used in CLIP only
+ sensitivity: Optional[Literal["high", "medium", "low"]] = (
+ "medium" # used in CLIP only
+ )
diff --git a/imageinf/inference/processor.py b/imageinf/inference/processor.py
index 12a4dde..eeaeb80 100644
--- a/imageinf/inference/processor.py
+++ b/imageinf/inference/processor.py
@@ -1,11 +1,11 @@
-from typing import List
+from typing import List, Optional
from tapipy.tapis import Tapis
from imageinf.utils.auth import TapisUser
from imageinf.utils.io import get_image_file
from .config import DEFAULT_MODEL_NAME
-from .registry import MODEL_REGISTRY
+from .registry import MODEL_REGISTRY, MODEL_METADATA
from .categories import aggregate_predictions
from .models import TapisFile, InferenceResult, InferenceResponse
@@ -16,12 +16,24 @@
# Public interface: plugin dispatch
def run_model_on_tapis_images(
- files: List[TapisFile], user: TapisUser, model_name: str = DEFAULT_MODEL_NAME
+ files: List[TapisFile],
+ user: TapisUser,
+ model_name: str = DEFAULT_MODEL_NAME,
+ labels: Optional[List[str]] = None, # only for CLIP
+ sensitivity: str = "medium", # only for CLIP
) -> InferenceResponse:
if model_name not in MODEL_REGISTRY:
raise ValueError(f"Model '{model_name}' is not supported.")
+
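+    # The registry records each model's "type"; CLIP runners take extra labels/sensitivity options below.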
+ model_meta = MODEL_METADATA[model_name]
ModelClass = MODEL_REGISTRY[model_name]
- model = ModelClass(model_name)
+
+ if model_meta["type"] == "clip":
+ # pass labels for CLIP
+ model = ModelClass(model_name, labels=labels)
+ else:
+ model = ModelClass(model_name)
+
tapis = Tapis(base_url=user.tenant_host, access_token=user.tapis_token)
results = []
@@ -30,7 +42,11 @@ def run_model_on_tapis_images(
for file in files:
try:
image, metadata = get_image_file(tapis, file.systemId, file.path)
- predictions = model.classify_image(image)
+
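+            # Only CLIP runners accept a sensitivity level; other models classify with their defaults.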
+ if model_meta["type"] == "clip":
+ predictions = model.classify_image(image, sensitivity=sensitivity)
+ else:
+ predictions = model.classify_image(image)
# Always create detailed results
results.append(
@@ -42,19 +58,18 @@ def run_model_on_tapis_images(
)
)
- # Always create aggregated results (skip for CLIP)
- if "clip" not in model_name.lower():
- aggregated = aggregate_predictions(predictions)
+ if model_meta["type"] == "clip":
+ # For CLIP, just copy the results since it's already aggregated
aggregated_results.append(
InferenceResult(
- systemId=file.systemId, path=file.path, predictions=aggregated
+ systemId=file.systemId, path=file.path, predictions=predictions
)
)
else:
- # For CLIP, just copy the results since it's already aggregated
+ aggregated = aggregate_predictions(predictions)
aggregated_results.append(
InferenceResult(
- systemId=file.systemId, path=file.path, predictions=predictions
+ systemId=file.systemId, path=file.path, predictions=aggregated
)
)
diff --git a/imageinf/inference/registry.py b/imageinf/inference/registry.py
index b283ec6..8d48630 100644
--- a/imageinf/inference/registry.py
+++ b/imageinf/inference/registry.py
@@ -2,11 +2,12 @@
MODEL_METADATA = {}
-def register_model_runner(model_name, description=None, link=None):
+def register_model_runner(model_name, model_type, description=None, link=None):
def decorator(cls):
MODEL_REGISTRY[model_name] = cls
MODEL_METADATA[model_name] = {
"name": model_name,
+ "type": model_type,
"description": description or model_name,
"link": link or "",
}
diff --git a/imageinf/inference/routes.py b/imageinf/inference/routes.py
index dc3755f..8528098 100644
--- a/imageinf/inference/routes.py
+++ b/imageinf/inference/routes.py
@@ -43,6 +43,12 @@ def run_sync_inference(
raise HTTPException(400, detail="Too many files. Use async endpoint for >5.")
try:
- return run_model_on_tapis_images(request.files, user, request.model)
+ return run_model_on_tapis_images(
+ request.files,
+ user,
+ request.model,
+ labels=request.labels,
+ sensitivity=request.sensitivity,
+ )
except ValueError as e:
raise HTTPException(400, detail=str(e))
diff --git a/imageinf/inference/vit_models.py b/imageinf/inference/vit_models.py
index 953e80b..af139c6 100644
--- a/imageinf/inference/vit_models.py
+++ b/imageinf/inference/vit_models.py
@@ -4,6 +4,7 @@
@register_model_runner(
"google/vit-base-patch16-224",
+ model_type="vit",
description="Vision Transformer (ViT) base model - 86M params, 224x224",
link="https://huggingface.co/google/vit-base-patch16-224",
)
@@ -13,6 +14,7 @@ class ViTBaseModel(TransformerModel):
@register_model_runner(
"google/vit-large-patch16-224",
+ model_type="vit",
description="Vision Transformer (ViT) large model - 304M params, 224x224",
link="https://huggingface.co/google/vit-large-patch16-224",
)
@@ -22,6 +24,7 @@ class ViTLargeModel(TransformerModel):
@register_model_runner(
"google/vit-large-patch16-384",
+ model_type="vit",
description=(
"Vision Transformer (ViT) large model - 304M params, 384x384 (high res)"
),
@@ -43,6 +46,7 @@ class ViTHugeModel(TransformerModel):
@register_model_runner(
"microsoft/swin-large-patch4-window7-224",
+ model_type="vit",
description="Swin Transformer large - 197M params, 224x224",
link="https://huggingface.co/microsoft/swin-large-patch4-window7-224",
)
diff --git a/requirements.txt b/requirements.txt
index 00f692f..4f23f97 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,8 +4,8 @@ uvicorn[standard]
httpx
tapipy
Pillow
-torch
-transformers
+torch>=2.10,<3
+transformers>=4.57,<5
huggingface_hub[hf_xet]
pytest
pytest-cov