PlanOpticon

planopticon / video_processor / providers / vertex_provider.py
Source Blame History 226 lines
0981a08… noreply 1 """Google Vertex AI provider implementation."""
0981a08… noreply 2
0981a08… noreply 3 import logging
0981a08… noreply 4 import os
0981a08… noreply 5 from pathlib import Path
0981a08… noreply 6 from typing import Optional
0981a08… noreply 7
0981a08… noreply 8 from dotenv import load_dotenv
0981a08… noreply 9
0981a08… noreply 10 from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry
0981a08… noreply 11
# Load a local .env file (if present) into the process environment so the
# GOOGLE_CLOUD_* settings read below can be resolved without exporting them.
load_dotenv()
logger = logging.getLogger(__name__)
# Curated list of models available on Vertex AI.
# NOTE(review): this is a static allow-list, not a live query of the Vertex
# model catalog — keep it in sync with what the project/region actually serves.
_VERTEX_MODELS = [
    ModelInfo(
        id="gemini-2.0-flash",
        provider="vertex",
        display_name="Gemini 2.0 Flash",
        capabilities=["chat", "vision", "audio"],
    ),
    ModelInfo(
        id="gemini-2.0-pro",
        provider="vertex",
        display_name="Gemini 2.0 Pro",
        capabilities=["chat", "vision", "audio"],
    ),
    ModelInfo(
        id="gemini-1.5-pro",
        provider="vertex",
        display_name="Gemini 1.5 Pro",
        capabilities=["chat", "vision", "audio"],
    ),
    ModelInfo(
        id="gemini-1.5-flash",
        provider="vertex",
        display_name="Gemini 1.5 Flash",
        capabilities=["chat", "vision", "audio"],
    ),
]
0981a08… noreply 43
class VertexProvider(BaseProvider):
    """Google Vertex AI provider using the google-genai SDK with Vertex config.

    Configuration comes from the constructor or the environment:
    ``GOOGLE_CLOUD_PROJECT`` (required) and ``GOOGLE_CLOUD_REGION``
    (defaults to ``us-central1``). Token usage for the most recent call is
    kept in ``self._last_usage``.
    """

    provider_name = "vertex"

    # Model used whenever a caller does not specify one.
    _DEFAULT_MODEL = "gemini-2.0-flash"

    def __init__(
        self,
        project: Optional[str] = None,
        location: Optional[str] = None,
    ):
        """Create a Vertex-backed genai client.

        Args:
            project: GCP project id; falls back to ``GOOGLE_CLOUD_PROJECT``.
            location: GCP region; falls back to ``GOOGLE_CLOUD_REGION``,
                then ``us-central1``.

        Raises:
            ImportError: if the google-genai SDK is not installed.
            ValueError: if no project id can be resolved.
        """
        try:
            from google import genai
            from google.genai import types  # noqa: F401  (import works => SDK present)
        except ImportError as err:
            # Chain the cause so the underlying missing-module detail survives.
            raise ImportError(
                "google-cloud-aiplatform or google-genai package not installed. "
                "Install with: pip install google-cloud-aiplatform"
            ) from err

        self._genai = genai
        self._project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
        self._location = location or os.getenv("GOOGLE_CLOUD_REGION", "us-central1")

        if not self._project:
            raise ValueError("GOOGLE_CLOUD_PROJECT not set")

        self.client = genai.Client(
            vertexai=True,
            project=self._project,
            location=self._location,
        )
        # Token counts from the most recent request; see _capture_usage().
        self._last_usage = {}

    @staticmethod
    def _resolve_model(model: Optional[str]) -> str:
        """Apply the default model and strip the router's 'vertex/' prefix."""
        model = model or VertexProvider._DEFAULT_MODEL
        if model.startswith("vertex/"):
            model = model[len("vertex/") :]
        return model

    def _capture_usage(self, response) -> None:
        """Record prompt/candidate token counts from a generate_content response."""
        um = getattr(response, "usage_metadata", None)
        self._last_usage = {
            "input_tokens": getattr(um, "prompt_token_count", 0) if um else 0,
            "output_tokens": getattr(um, "candidates_token_count", 0) if um else 0,
        }

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
        model: Optional[str] = None,
    ) -> str:
        """Run a multi-turn chat completion and return the response text.

        Args:
            messages: list of {"role", "content"} dicts; any non-"user" role
                is mapped to the Gemini "model" role.
            max_tokens: generation cap passed as max_output_tokens.
            temperature: sampling temperature.
            model: optional model id, with or without a "vertex/" prefix.

        Returns:
            The response text, or "" if the model returned no text.
        """
        from google.genai import types

        model = self._resolve_model(model)

        contents = []
        for msg in messages:
            # Gemini only knows "user" and "model" roles; fold everything
            # else (assistant/system) into "model".
            role = "user" if msg["role"] == "user" else "model"
            contents.append(
                types.Content(
                    role=role,
                    parts=[types.Part.from_text(text=msg["content"])],
                )
            )

        response = self.client.models.generate_content(
            model=model,
            contents=contents,
            config=types.GenerateContentConfig(
                max_output_tokens=max_tokens,
                temperature=temperature,
            ),
        )
        self._capture_usage(response)
        return response.text or ""

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
        model: Optional[str] = None,
    ) -> str:
        """Describe/analyze an image with a text prompt.

        Args:
            image_bytes: raw image data (sent as image/jpeg — callers must
                supply JPEG bytes or accept the mislabel).
            prompt: instruction for the model.
            max_tokens: generation cap passed as max_output_tokens.
            model: optional model id, with or without a "vertex/" prefix.

        Returns:
            The response text, or "" if the model returned no text.
        """
        from google.genai import types

        model = self._resolve_model(model)

        response = self.client.models.generate_content(
            model=model,
            contents=[
                types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg"),
                prompt,
            ],
            config=types.GenerateContentConfig(
                max_output_tokens=max_tokens,
            ),
        )
        self._capture_usage(response)
        return response.text or ""

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        model: Optional[str] = None,
    ) -> dict:
        """Transcribe an audio file via Gemini, returning text plus segments.

        Args:
            audio_path: path to the audio file; extension picks the MIME type
                (unknown extensions fall back to audio/wav).
            language: optional language hint inserted into the prompt.
            model: optional model id, with or without a "vertex/" prefix.

        Returns:
            Dict with keys: text, segments, language, duration (always None —
            Gemini does not report it), provider, model. If the model's JSON
            is unparseable, the raw text is returned with empty segments.
        """
        import json

        from google.genai import types

        model = self._resolve_model(model)

        audio_path = Path(audio_path)
        suffix = audio_path.suffix.lower()
        mime_map = {
            ".wav": "audio/wav",
            ".mp3": "audio/mpeg",
            ".m4a": "audio/mp4",
            ".flac": "audio/flac",
            ".ogg": "audio/ogg",
            ".webm": "audio/webm",
        }
        mime_type = mime_map.get(suffix, "audio/wav")
        audio_bytes = audio_path.read_bytes()

        lang_hint = f" The audio is in {language}." if language else ""
        # BUG FIX: the original prompt contained the stray token "python" and
        # never interpolated lang_hint, so the language parameter had no effect.
        prompt = (
            f"Transcribe this audio accurately.{lang_hint} "
            "Return a JSON object with keys: "
            '"text" (full transcript), '
            '"segments" (array of {start, end, text} objects with timestamps in seconds).'
        )

        response = self.client.models.generate_content(
            model=model,
            contents=[
                types.Part.from_bytes(data=audio_bytes, mime_type=mime_type),
                prompt,
            ],
            config=types.GenerateContentConfig(
                max_output_tokens=8192,
                response_mime_type="application/json",
            ),
        )

        try:
            data = json.loads(response.text)
        except (json.JSONDecodeError, TypeError):
            # Best-effort fallback: keep whatever text came back, no segments.
            data = {"text": response.text or "", "segments": []}

        self._capture_usage(response)

        return {
            "text": data.get("text", ""),
            "segments": data.get("segments", []),
            "language": language,
            "duration": None,
            "provider": "vertex",
            "model": model,
        }

    def list_models(self) -> list[ModelInfo]:
        """Return a copy of the curated Vertex model list."""
        return list(_VERTEX_MODELS)
0981a08… noreply 215
# Module-level side effect: importing this module registers the provider so
# "vertex/..." model ids route here. Availability is gated on
# GOOGLE_CLOUD_PROJECT being set in the environment.
ProviderRegistry.register(
    name="vertex",
    provider_class=VertexProvider,
    env_var="GOOGLE_CLOUD_PROJECT",
    model_prefixes=["vertex/"],
    default_models={
        "chat": "gemini-2.0-flash",
        "vision": "gemini-2.0-flash",
        "audio": "gemini-2.0-flash",
    },
)

Keyboard Shortcuts

Open search /
Next entry (timeline) j
Previous entry (timeline) k
Open focused entry Enter
Show this help ?
Toggle theme Top nav button