PlanOpticon
feat(providers): add Bedrock, Vertex, Mistral, Cohere, AI21, HuggingFace, Qianfan, LiteLLM providers
Commit
010bba04cbf15541818ce0ebcfa5d6b1ad7e81e37fbbdedb6e3c491a5e457034
Parent
9975609f4e77601…
9 files changed
+9
-1
+98
+193
+123
+187
+171
+167
+138
+226
~
pyproject.toml
+
video_processor/providers/ai21_provider.py
+
video_processor/providers/bedrock_provider.py
+
video_processor/providers/cohere_provider.py
+
video_processor/providers/huggingface_provider.py
+
video_processor/providers/litellm_provider.py
+
video_processor/providers/mistral_provider.py
+
video_processor/providers/qianfan_provider.py
+
video_processor/providers/vertex_provider.py
+9
-1
| --- pyproject.toml | ||
| +++ pyproject.toml | ||
| @@ -50,19 +50,27 @@ | ||
| 50 | 50 | "requests>=2.31.0", |
| 51 | 51 | "tenacity>=8.2.0", |
| 52 | 52 | ] |
| 53 | 53 | |
| 54 | 54 | [project.optional-dependencies] |
| 55 | -pdf = ["weasyprint>=60.0"] | |
| 55 | +pdf = ["pymupdf>=1.24.0"] | |
| 56 | 56 | gpu = ["torch>=2.0.0", "torchvision>=0.15.0"] |
| 57 | 57 | gdrive = ["google-auth>=2.0.0", "google-auth-oauthlib>=1.0.0", "google-api-python-client>=2.0.0"] |
| 58 | 58 | dropbox = ["dropbox>=12.0.0"] |
| 59 | 59 | azure = ["openai>=1.0.0"] |
| 60 | 60 | together = ["openai>=1.0.0"] |
| 61 | 61 | fireworks = ["openai>=1.0.0"] |
| 62 | 62 | cerebras = ["openai>=1.0.0"] |
| 63 | 63 | xai = ["openai>=1.0.0"] |
| 64 | +bedrock = ["boto3>=1.28"] | |
| 65 | +vertex = ["google-cloud-aiplatform>=1.38"] | |
| 66 | +mistral = ["mistralai>=1.0"] | |
| 67 | +cohere = ["cohere>=5.0"] | |
| 68 | +ai21 = ["ai21>=3.0"] | |
| 69 | +huggingface = ["huggingface_hub>=0.20"] | |
| 70 | +qianfan = ["qianfan>=0.4"] | |
| 71 | +litellm = ["litellm>=1.0"] | |
| 64 | 72 | graph = [] |
| 65 | 73 | cloud = [ |
| 66 | 74 | "planopticon[gdrive]", |
| 67 | 75 | "planopticon[dropbox]", |
| 68 | 76 | ] |
| 69 | 77 | |
| 70 | 78 | ADDED video_processor/providers/ai21_provider.py |
| 71 | 79 | ADDED video_processor/providers/bedrock_provider.py |
| 72 | 80 | ADDED video_processor/providers/cohere_provider.py |
| 73 | 81 | ADDED video_processor/providers/huggingface_provider.py |
| 74 | 82 | ADDED video_processor/providers/litellm_provider.py |
| 75 | 83 | ADDED video_processor/providers/mistral_provider.py |
| 76 | 84 | ADDED video_processor/providers/qianfan_provider.py |
| 77 | 85 | ADDED video_processor/providers/vertex_provider.py |
| --- pyproject.toml | |
| +++ pyproject.toml | |
| @@ -50,19 +50,27 @@ | |
| 50 | "requests>=2.31.0", |
| 51 | "tenacity>=8.2.0", |
| 52 | ] |
| 53 | |
| 54 | [project.optional-dependencies] |
| 55 | pdf = ["weasyprint>=60.0"] |
| 56 | gpu = ["torch>=2.0.0", "torchvision>=0.15.0"] |
| 57 | gdrive = ["google-auth>=2.0.0", "google-auth-oauthlib>=1.0.0", "google-api-python-client>=2.0.0"] |
| 58 | dropbox = ["dropbox>=12.0.0"] |
| 59 | azure = ["openai>=1.0.0"] |
| 60 | together = ["openai>=1.0.0"] |
| 61 | fireworks = ["openai>=1.0.0"] |
| 62 | cerebras = ["openai>=1.0.0"] |
| 63 | xai = ["openai>=1.0.0"] |
| 64 | graph = [] |
| 65 | cloud = [ |
| 66 | "planopticon[gdrive]", |
| 67 | "planopticon[dropbox]", |
| 68 | ] |
| 69 | |
| 70 | ADDED video_processor/providers/ai21_provider.py |
| 71 | ADDED video_processor/providers/bedrock_provider.py |
| 72 | ADDED video_processor/providers/cohere_provider.py |
| 73 | ADDED video_processor/providers/huggingface_provider.py |
| 74 | ADDED video_processor/providers/litellm_provider.py |
| 75 | ADDED video_processor/providers/mistral_provider.py |
| 76 | ADDED video_processor/providers/qianfan_provider.py |
| 77 | ADDED video_processor/providers/vertex_provider.py |
| --- pyproject.toml | |
| +++ pyproject.toml | |
| @@ -50,19 +50,27 @@ | |
| 50 | "requests>=2.31.0", |
| 51 | "tenacity>=8.2.0", |
| 52 | ] |
| 53 | |
| 54 | [project.optional-dependencies] |
| 55 | pdf = ["pymupdf>=1.24.0"] |
| 56 | gpu = ["torch>=2.0.0", "torchvision>=0.15.0"] |
| 57 | gdrive = ["google-auth>=2.0.0", "google-auth-oauthlib>=1.0.0", "google-api-python-client>=2.0.0"] |
| 58 | dropbox = ["dropbox>=12.0.0"] |
| 59 | azure = ["openai>=1.0.0"] |
| 60 | together = ["openai>=1.0.0"] |
| 61 | fireworks = ["openai>=1.0.0"] |
| 62 | cerebras = ["openai>=1.0.0"] |
| 63 | xai = ["openai>=1.0.0"] |
| 64 | bedrock = ["boto3>=1.28"] |
| 65 | vertex = ["google-cloud-aiplatform>=1.38"] |
| 66 | mistral = ["mistralai>=1.0"] |
| 67 | cohere = ["cohere>=5.0"] |
| 68 | ai21 = ["ai21>=3.0"] |
| 69 | huggingface = ["huggingface_hub>=0.20"] |
| 70 | qianfan = ["qianfan>=0.4"] |
| 71 | litellm = ["litellm>=1.0"] |
| 72 | graph = [] |
| 73 | cloud = [ |
| 74 | "planopticon[gdrive]", |
| 75 | "planopticon[dropbox]", |
| 76 | ] |
| 77 | |
| 78 | ADDED video_processor/providers/ai21_provider.py |
| 79 | ADDED video_processor/providers/bedrock_provider.py |
| 80 | ADDED video_processor/providers/cohere_provider.py |
| 81 | ADDED video_processor/providers/huggingface_provider.py |
| 82 | ADDED video_processor/providers/litellm_provider.py |
| 83 | ADDED video_processor/providers/mistral_provider.py |
| 84 | ADDED video_processor/providers/qianfan_provider.py |
| 85 | ADDED video_processor/providers/vertex_provider.py |
| --- a/video_processor/providers/ai21_provider.py | ||
| +++ b/video_processor/providers/ai21_provider.py | ||
| @@ -0,0 +1,98 @@ | ||
| 1 | +"""AI21 Labs provider implementation.""" | |
| 2 | + | |
| 3 | +import logging | |
| 4 | +import os | |
| 5 | +from pathlib import Path | |
| 6 | +from typing import Optional | |
| 7 | + | |
| 8 | +from dotenv import load_dotenv | |
| 9 | + | |
| 10 | +from video_processor.providers.base import ModelInfo, OpenAICompatibleProvider, ProviderRegistry | |
| 11 | + | |
| 12 | +load_dotenv() | |
| 13 | +logger = logging.getLogger(__name__) | |
| 14 | + | |
| 15 | +# Curated list of AI21 models | |
| 16 | +_AI21_MODELS = [ | |
| 17 | + ModelInfo( | |
| 18 | + id="jamba-1.5-large", | |
| 19 | + provider="ai21", | |
| 20 | + display_name="Jamba 1.5 Large", | |
| 21 | + capabilities=["chat"], | |
| 22 | + ), | |
| 23 | + ModelInfo( | |
| 24 | + id="jamba-1.5-mini", | |
| 25 | + provider="ai21", | |
| 26 | + display_name="Jamba 1.5 Mini", | |
| 27 | + capabilities=["chat"], | |
| 28 | + ), | |
| 29 | + ModelInfo( | |
| 30 | + id="jamba-instruct", | |
| 31 | + provider="ai21", | |
| 32 | + display_name="Jamba Instruct", | |
| 33 | + capabilities=["chat"], | |
| 34 | + ), | |
| 35 | +] | |
| 36 | + | |
| 37 | + | |
| 38 | +class AI21Provider(OpenAICompatibleProvider): | |
| 39 | + """AI21 Labs provider using OpenAI-compatible API.""" | |
| 40 | + | |
| 41 | + provider_name = "ai21" | |
| 42 | + base_url = "https://api.ai21.com/studio/v1" | |
| 43 | + env_var = "AI21_API_KEY" | |
| 44 | + | |
| 45 | + def __init__(self, api_key: Optional[str] = None): | |
| 46 | + api_key = api_key or os.getenv("AI21_API_KEY") | |
| 47 | + if not api_key: | |
| 48 | + raise ValueError("AI21_API_KEY not set") | |
| 49 | + super().__init__(api_key=api_key, base_url=self.base_url) | |
| 50 | + | |
| 51 | + def chat( | |
| 52 | + self, | |
| 53 | + messages: list[dict], | |
| 54 | + max_tokens: int = 4096, | |
| 55 | + temperature: float = 0.7, | |
| 56 | + model: Optional[str] = None, | |
| 57 | + ) -> str: | |
| 58 | + model = model or "jamba-1.5-large" | |
| 59 | + return super().chat(messages, max_tokens, temperature, model) | |
| 60 | + | |
| 61 | + def analyze_image( | |
| 62 | + self, | |
| 63 | + image_bytes: bytes, | |
| 64 | + prompt: str, | |
| 65 | + max_tokens: int = 4096, | |
| 66 | + model: Optional[str] = None, | |
| 67 | + ) -> str: | |
| 68 | + raise NotImplementedError( | |
| 69 | + "AI21 does not currently support vision/image analysis. " | |
| 70 | + "Use OpenAI, Anthropic, or Gemini for image analysis." | |
| 71 | + ) | |
| 72 | + | |
| 73 | + def transcribe_audio( | |
| 74 | + self, | |
| 75 | + audio_path: str | Path, | |
| 76 | + language: Optional[str] = None, | |
| 77 | + model: Optional[str] = None, | |
| 78 | + ) -> dict: | |
| 79 | + raise NotImplementedError( | |
| 80 | + "AI21 does not provide a transcription API. " | |
| 81 | + "Use OpenAI Whisper or Gemini for transcription." | |
| 82 | + ) | |
| 83 | + | |
| 84 | + def list_models(self) -> list[ModelInfo]: | |
| 85 | + return list(_AI21_MODELS) | |
| 86 | + | |
| 87 | + | |
| 88 | +ProviderRegistry.register( | |
| 89 | + name="ai21", | |
| 90 | + provider_class=AI21Provider, | |
| 91 | + env_var="AI21_API_KEY", | |
| 92 | + model_prefixes=["jamba-", "j2-"], | |
| 93 | + default_models={ | |
| 94 | + "chat": "jamba-1.5-large", | |
| 95 | + "vision": "", | |
| 96 | + "audio": "", | |
| 97 | + }, | |
| 98 | +) |
| --- a/video_processor/providers/ai21_provider.py | |
| +++ b/video_processor/providers/ai21_provider.py | |
| @@ -0,0 +1,98 @@ | |
| --- a/video_processor/providers/ai21_provider.py | |
| +++ b/video_processor/providers/ai21_provider.py | |
| @@ -0,0 +1,98 @@ | |
| 1 | """AI21 Labs provider implementation.""" |
| 2 | |
| 3 | import logging |
| 4 | import os |
| 5 | from pathlib import Path |
| 6 | from typing import Optional |
| 7 | |
| 8 | from dotenv import load_dotenv |
| 9 | |
| 10 | from video_processor.providers.base import ModelInfo, OpenAICompatibleProvider, ProviderRegistry |
| 11 | |
| 12 | load_dotenv() |
| 13 | logger = logging.getLogger(__name__) |
| 14 | |
| 15 | # Curated list of AI21 models |
| 16 | _AI21_MODELS = [ |
| 17 | ModelInfo( |
| 18 | id="jamba-1.5-large", |
| 19 | provider="ai21", |
| 20 | display_name="Jamba 1.5 Large", |
| 21 | capabilities=["chat"], |
| 22 | ), |
| 23 | ModelInfo( |
| 24 | id="jamba-1.5-mini", |
| 25 | provider="ai21", |
| 26 | display_name="Jamba 1.5 Mini", |
| 27 | capabilities=["chat"], |
| 28 | ), |
| 29 | ModelInfo( |
| 30 | id="jamba-instruct", |
| 31 | provider="ai21", |
| 32 | display_name="Jamba Instruct", |
| 33 | capabilities=["chat"], |
| 34 | ), |
| 35 | ] |
| 36 | |
| 37 | |
| 38 | class AI21Provider(OpenAICompatibleProvider): |
| 39 | """AI21 Labs provider using OpenAI-compatible API.""" |
| 40 | |
| 41 | provider_name = "ai21" |
| 42 | base_url = "https://api.ai21.com/studio/v1" |
| 43 | env_var = "AI21_API_KEY" |
| 44 | |
| 45 | def __init__(self, api_key: Optional[str] = None): |
| 46 | api_key = api_key or os.getenv("AI21_API_KEY") |
| 47 | if not api_key: |
| 48 | raise ValueError("AI21_API_KEY not set") |
| 49 | super().__init__(api_key=api_key, base_url=self.base_url) |
| 50 | |
| 51 | def chat( |
| 52 | self, |
| 53 | messages: list[dict], |
| 54 | max_tokens: int = 4096, |
| 55 | temperature: float = 0.7, |
| 56 | model: Optional[str] = None, |
| 57 | ) -> str: |
| 58 | model = model or "jamba-1.5-large" |
| 59 | return super().chat(messages, max_tokens, temperature, model) |
| 60 | |
| 61 | def analyze_image( |
| 62 | self, |
| 63 | image_bytes: bytes, |
| 64 | prompt: str, |
| 65 | max_tokens: int = 4096, |
| 66 | model: Optional[str] = None, |
| 67 | ) -> str: |
| 68 | raise NotImplementedError( |
| 69 | "AI21 does not currently support vision/image analysis. " |
| 70 | "Use OpenAI, Anthropic, or Gemini for image analysis." |
| 71 | ) |
| 72 | |
| 73 | def transcribe_audio( |
| 74 | self, |
| 75 | audio_path: str | Path, |
| 76 | language: Optional[str] = None, |
| 77 | model: Optional[str] = None, |
| 78 | ) -> dict: |
| 79 | raise NotImplementedError( |
| 80 | "AI21 does not provide a transcription API. " |
| 81 | "Use OpenAI Whisper or Gemini for transcription." |
| 82 | ) |
| 83 | |
| 84 | def list_models(self) -> list[ModelInfo]: |
| 85 | return list(_AI21_MODELS) |
| 86 | |
| 87 | |
| 88 | ProviderRegistry.register( |
| 89 | name="ai21", |
| 90 | provider_class=AI21Provider, |
| 91 | env_var="AI21_API_KEY", |
| 92 | model_prefixes=["jamba-", "j2-"], |
| 93 | default_models={ |
| 94 | "chat": "jamba-1.5-large", |
| 95 | "vision": "", |
| 96 | "audio": "", |
| 97 | }, |
| 98 | ) |
| --- a/video_processor/providers/bedrock_provider.py | ||
| +++ b/video_processor/providers/bedrock_provider.py | ||
| @@ -0,0 +1,193 @@ | ||
| 1 | +"""AWS Bedrock provider implementation.""" | |
| 2 | + | |
| 3 | +import base64 | |
| 4 | +import json | |
| 5 | +import logging | |
| 6 | +import os | |
| 7 | +from pathlib import Path | |
| 8 | +from typing import Optional | |
| 9 | + | |
| 10 | +from dotenv import load_dotenv | |
| 11 | + | |
| 12 | +from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry | |
| 13 | + | |
| 14 | +load_dotenv() | |
| 15 | +logger = logging.getLogger(__name__) | |
| 16 | + | |
| 17 | +# Curated list of popular Bedrock models | |
| 18 | +_BEDROCK_MODELS = [ | |
| 19 | + ModelInfo( | |
| 20 | + id="anthropic.claude-3-5-sonnet-20241022-v2:0", | |
| 21 | + provider="bedrock", | |
| 22 | + display_name="Claude 3.5 Sonnet v2", | |
| 23 | + capabilities=["chat", "vision"], | |
| 24 | + ), | |
| 25 | + ModelInfo( | |
| 26 | + id="anthropic.claude-3-sonnet-20240229-v1:0", | |
| 27 | + provider="bedrock", | |
| 28 | + display_name="Claude 3 Sonnet", | |
| 29 | + capabilities=["chat", "vision"], | |
| 30 | + ), | |
| 31 | + ModelInfo( | |
| 32 | + id="anthropic.claude-3-haiku-20240307-v1:0", | |
| 33 | + provider="bedrock", | |
| 34 | + display_name="Claude 3 Haiku", | |
| 35 | + capabilities=["chat", "vision"], | |
| 36 | + ), | |
| 37 | + ModelInfo( | |
| 38 | + id="amazon.titan-text-express-v1", | |
| 39 | + provider="bedrock", | |
| 40 | + display_name="Amazon Titan Text Express", | |
| 41 | + capabilities=["chat"], | |
| 42 | + ), | |
| 43 | + ModelInfo( | |
| 44 | + id="meta.llama3-70b-instruct-v1:0", | |
| 45 | + provider="bedrock", | |
| 46 | + display_name="Llama 3 70B Instruct", | |
| 47 | + capabilities=["chat"], | |
| 48 | + ), | |
| 49 | + ModelInfo( | |
| 50 | + id="mistral.mistral-large-2402-v1:0", | |
| 51 | + provider="bedrock", | |
| 52 | + display_name="Mistral Large", | |
| 53 | + capabilities=["chat"], | |
| 54 | + ), | |
| 55 | +] | |
| 56 | + | |
| 57 | + | |
| 58 | +class BedrockProvider(BaseProvider): | |
| 59 | + """AWS Bedrock provider using boto3.""" | |
| 60 | + | |
| 61 | + provider_name = "bedrock" | |
| 62 | + | |
| 63 | + def __init__( | |
| 64 | + self, | |
| 65 | + aws_access_key_id: Optional[str] = None, | |
| 66 | + aws_secret_access_key: Optional[str] = None, | |
| 67 | + region_name: Optional[str] = None, | |
| 68 | + ): | |
| 69 | + try: | |
| 70 | + import boto3 | |
| 71 | + except ImportError: | |
| 72 | + raise ImportError("boto3 package not installed. Install with: pip install boto3") | |
| 73 | + | |
| 74 | + self._boto3 = boto3 | |
| 75 | + self._region = region_name or os.getenv("AWS_DEFAULT_REGION", "us-east-1") | |
| 76 | + self._client = boto3.client( | |
| 77 | + "bedrock-runtime", | |
| 78 | + aws_access_key_id=aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"), | |
| 79 | + aws_secret_access_key=aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY"), | |
| 80 | + region_name=self._region, | |
| 81 | + ) | |
| 82 | + self._last_usage = {} | |
| 83 | + | |
| 84 | + def chat( | |
| 85 | + self, | |
| 86 | + messages: list[dict], | |
| 87 | + max_tokens: int = 4096, | |
| 88 | + temperature: float = 0.7, | |
| 89 | + model: Optional[str] = None, | |
| 90 | + ) -> str: | |
| 91 | + model = model or "anthropic.claude-3-sonnet-20240229-v1:0" | |
| 92 | + # Strip bedrock/ prefix if present | |
| 93 | + if model.startswith("bedrock/"): | |
| 94 | + model = model[len("bedrock/") :] | |
| 95 | + | |
| 96 | + body = json.dumps( | |
| 97 | + { | |
| 98 | + "anthropic_version": "bedrock-2023-05-31", | |
| 99 | + "max_tokens": max_tokens, | |
| 100 | + "temperature": temperature, | |
| 101 | + "messages": messages, | |
| 102 | + } | |
| 103 | + ) | |
| 104 | + | |
| 105 | + response = self._client.invoke_model( | |
| 106 | + modelId=model, | |
| 107 | + contentType="application/json", | |
| 108 | + accept="application/json", | |
| 109 | + body=body, | |
| 110 | + ) | |
| 111 | + | |
| 112 | + result = json.loads(response["body"].read()) | |
| 113 | + self._last_usage = { | |
| 114 | + "input_tokens": result.get("usage", {}).get("input_tokens", 0), | |
| 115 | + "output_tokens": result.get("usage", {}).get("output_tokens", 0), | |
| 116 | + } | |
| 117 | + return result.get("content", [{}])[0].get("text", "") | |
| 118 | + | |
| 119 | + def analyze_image( | |
| 120 | + self, | |
| 121 | + image_bytes: bytes, | |
| 122 | + prompt: str, | |
| 123 | + max_tokens: int = 4096, | |
| 124 | + model: Optional[str] = None, | |
| 125 | + ) -> str: | |
| 126 | + model = model or "anthropic.claude-3-sonnet-20240229-v1:0" | |
| 127 | + if model.startswith("bedrock/"): | |
| 128 | + model = model[len("bedrock/") :] | |
| 129 | + | |
| 130 | + b64 = base64.b64encode(image_bytes).decode() | |
| 131 | + body = json.dumps( | |
| 132 | + { | |
| 133 | + "anthropic_version": "bedrock-2023-05-31", | |
| 134 | + "max_tokens": max_tokens, | |
| 135 | + "messages": [ | |
| 136 | + { | |
| 137 | + "role": "user", | |
| 138 | + "content": [ | |
| 139 | + { | |
| 140 | + "type": "image", | |
| 141 | + "source": { | |
| 142 | + "type": "base64", | |
| 143 | + "media_type": "image/jpeg", | |
| 144 | + "data": b64, | |
| 145 | + }, | |
| 146 | + }, | |
| 147 | + {"type": "text", "text": prompt}, | |
| 148 | + ], | |
| 149 | + } | |
| 150 | + ], | |
| 151 | + } | |
| 152 | + ) | |
| 153 | + | |
| 154 | + response = self._client.invoke_model( | |
| 155 | + modelId=model, | |
| 156 | + contentType="application/json", | |
| 157 | + accept="application/json", | |
| 158 | + body=body, | |
| 159 | + ) | |
| 160 | + | |
| 161 | + result = json.loads(response["body"].read()) | |
| 162 | + self._last_usage = { | |
| 163 | + "input_tokens": result.get("usage", {}).get("input_tokens", 0), | |
| 164 | + "output_tokens": result.get("usage", {}).get("output_tokens", 0), | |
| 165 | + } | |
| 166 | + return result.get("content", [{}])[0].get("text", "") | |
| 167 | + | |
| 168 | + def transcribe_audio( | |
| 169 | + self, | |
| 170 | + audio_path: str | Path, | |
| 171 | + language: Optional[str] = None, | |
| 172 | + model: Optional[str] = None, | |
| 173 | + ) -> dict: | |
| 174 | + raise NotImplementedError( | |
| 175 | + "AWS Bedrock does not support audio transcription directly. " | |
| 176 | + "Use Amazon Transcribe or another provider for transcription." | |
| 177 | + ) | |
| 178 | + | |
| 179 | + def list_models(self) -> list[ModelInfo]: | |
| 180 | + return list(_BEDROCK_MODELS) | |
| 181 | + | |
| 182 | + | |
| 183 | +ProviderRegistry.register( | |
| 184 | + name="bedrock", | |
| 185 | + provider_class=BedrockProvider, | |
| 186 | + env_var="AWS_ACCESS_KEY_ID", | |
| 187 | + model_prefixes=["bedrock/"], | |
| 188 | + default_models={ | |
| 189 | + "chat": "anthropic.claude-3-sonnet-20240229-v1:0", | |
| 190 | + "vision": "anthropic.claude-3-sonnet-20240229-v1:0", | |
| 191 | + "audio": "", | |
| 192 | + }, | |
| 193 | +) |
| --- a/video_processor/providers/bedrock_provider.py | |
| +++ b/video_processor/providers/bedrock_provider.py | |
| @@ -0,0 +1,193 @@ | |
| --- a/video_processor/providers/bedrock_provider.py | |
| +++ b/video_processor/providers/bedrock_provider.py | |
| @@ -0,0 +1,193 @@ | |
| 1 | """AWS Bedrock provider implementation.""" |
| 2 | |
| 3 | import base64 |
| 4 | import json |
| 5 | import logging |
| 6 | import os |
| 7 | from pathlib import Path |
| 8 | from typing import Optional |
| 9 | |
| 10 | from dotenv import load_dotenv |
| 11 | |
| 12 | from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry |
| 13 | |
| 14 | load_dotenv() |
| 15 | logger = logging.getLogger(__name__) |
| 16 | |
| 17 | # Curated list of popular Bedrock models |
| 18 | _BEDROCK_MODELS = [ |
| 19 | ModelInfo( |
| 20 | id="anthropic.claude-3-5-sonnet-20241022-v2:0", |
| 21 | provider="bedrock", |
| 22 | display_name="Claude 3.5 Sonnet v2", |
| 23 | capabilities=["chat", "vision"], |
| 24 | ), |
| 25 | ModelInfo( |
| 26 | id="anthropic.claude-3-sonnet-20240229-v1:0", |
| 27 | provider="bedrock", |
| 28 | display_name="Claude 3 Sonnet", |
| 29 | capabilities=["chat", "vision"], |
| 30 | ), |
| 31 | ModelInfo( |
| 32 | id="anthropic.claude-3-haiku-20240307-v1:0", |
| 33 | provider="bedrock", |
| 34 | display_name="Claude 3 Haiku", |
| 35 | capabilities=["chat", "vision"], |
| 36 | ), |
| 37 | ModelInfo( |
| 38 | id="amazon.titan-text-express-v1", |
| 39 | provider="bedrock", |
| 40 | display_name="Amazon Titan Text Express", |
| 41 | capabilities=["chat"], |
| 42 | ), |
| 43 | ModelInfo( |
| 44 | id="meta.llama3-70b-instruct-v1:0", |
| 45 | provider="bedrock", |
| 46 | display_name="Llama 3 70B Instruct", |
| 47 | capabilities=["chat"], |
| 48 | ), |
| 49 | ModelInfo( |
| 50 | id="mistral.mistral-large-2402-v1:0", |
| 51 | provider="bedrock", |
| 52 | display_name="Mistral Large", |
| 53 | capabilities=["chat"], |
| 54 | ), |
| 55 | ] |
| 56 | |
| 57 | |
| 58 | class BedrockProvider(BaseProvider): |
| 59 | """AWS Bedrock provider using boto3.""" |
| 60 | |
| 61 | provider_name = "bedrock" |
| 62 | |
| 63 | def __init__( |
| 64 | self, |
| 65 | aws_access_key_id: Optional[str] = None, |
| 66 | aws_secret_access_key: Optional[str] = None, |
| 67 | region_name: Optional[str] = None, |
| 68 | ): |
| 69 | try: |
| 70 | import boto3 |
| 71 | except ImportError: |
| 72 | raise ImportError("boto3 package not installed. Install with: pip install boto3") |
| 73 | |
| 74 | self._boto3 = boto3 |
| 75 | self._region = region_name or os.getenv("AWS_DEFAULT_REGION", "us-east-1") |
| 76 | self._client = boto3.client( |
| 77 | "bedrock-runtime", |
| 78 | aws_access_key_id=aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"), |
| 79 | aws_secret_access_key=aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY"), |
| 80 | region_name=self._region, |
| 81 | ) |
| 82 | self._last_usage = {} |
| 83 | |
| 84 | def chat( |
| 85 | self, |
| 86 | messages: list[dict], |
| 87 | max_tokens: int = 4096, |
| 88 | temperature: float = 0.7, |
| 89 | model: Optional[str] = None, |
| 90 | ) -> str: |
| 91 | model = model or "anthropic.claude-3-sonnet-20240229-v1:0" |
| 92 | # Strip bedrock/ prefix if present |
| 93 | if model.startswith("bedrock/"): |
| 94 | model = model[len("bedrock/") :] |
| 95 | |
| 96 | body = json.dumps( |
| 97 | { |
| 98 | "anthropic_version": "bedrock-2023-05-31", |
| 99 | "max_tokens": max_tokens, |
| 100 | "temperature": temperature, |
| 101 | "messages": messages, |
| 102 | } |
| 103 | ) |
| 104 | |
| 105 | response = self._client.invoke_model( |
| 106 | modelId=model, |
| 107 | contentType="application/json", |
| 108 | accept="application/json", |
| 109 | body=body, |
| 110 | ) |
| 111 | |
| 112 | result = json.loads(response["body"].read()) |
| 113 | self._last_usage = { |
| 114 | "input_tokens": result.get("usage", {}).get("input_tokens", 0), |
| 115 | "output_tokens": result.get("usage", {}).get("output_tokens", 0), |
| 116 | } |
| 117 | return result.get("content", [{}])[0].get("text", "") |
| 118 | |
| 119 | def analyze_image( |
| 120 | self, |
| 121 | image_bytes: bytes, |
| 122 | prompt: str, |
| 123 | max_tokens: int = 4096, |
| 124 | model: Optional[str] = None, |
| 125 | ) -> str: |
| 126 | model = model or "anthropic.claude-3-sonnet-20240229-v1:0" |
| 127 | if model.startswith("bedrock/"): |
| 128 | model = model[len("bedrock/") :] |
| 129 | |
| 130 | b64 = base64.b64encode(image_bytes).decode() |
| 131 | body = json.dumps( |
| 132 | { |
| 133 | "anthropic_version": "bedrock-2023-05-31", |
| 134 | "max_tokens": max_tokens, |
| 135 | "messages": [ |
| 136 | { |
| 137 | "role": "user", |
| 138 | "content": [ |
| 139 | { |
| 140 | "type": "image", |
| 141 | "source": { |
| 142 | "type": "base64", |
| 143 | "media_type": "image/jpeg", |
| 144 | "data": b64, |
| 145 | }, |
| 146 | }, |
| 147 | {"type": "text", "text": prompt}, |
| 148 | ], |
| 149 | } |
| 150 | ], |
| 151 | } |
| 152 | ) |
| 153 | |
| 154 | response = self._client.invoke_model( |
| 155 | modelId=model, |
| 156 | contentType="application/json", |
| 157 | accept="application/json", |
| 158 | body=body, |
| 159 | ) |
| 160 | |
| 161 | result = json.loads(response["body"].read()) |
| 162 | self._last_usage = { |
| 163 | "input_tokens": result.get("usage", {}).get("input_tokens", 0), |
| 164 | "output_tokens": result.get("usage", {}).get("output_tokens", 0), |
| 165 | } |
| 166 | return result.get("content", [{}])[0].get("text", "") |
| 167 | |
| 168 | def transcribe_audio( |
| 169 | self, |
| 170 | audio_path: str | Path, |
| 171 | language: Optional[str] = None, |
| 172 | model: Optional[str] = None, |
| 173 | ) -> dict: |
| 174 | raise NotImplementedError( |
| 175 | "AWS Bedrock does not support audio transcription directly. " |
| 176 | "Use Amazon Transcribe or another provider for transcription." |
| 177 | ) |
| 178 | |
| 179 | def list_models(self) -> list[ModelInfo]: |
| 180 | return list(_BEDROCK_MODELS) |
| 181 | |
| 182 | |
| 183 | ProviderRegistry.register( |
| 184 | name="bedrock", |
| 185 | provider_class=BedrockProvider, |
| 186 | env_var="AWS_ACCESS_KEY_ID", |
| 187 | model_prefixes=["bedrock/"], |
| 188 | default_models={ |
| 189 | "chat": "anthropic.claude-3-sonnet-20240229-v1:0", |
| 190 | "vision": "anthropic.claude-3-sonnet-20240229-v1:0", |
| 191 | "audio": "", |
| 192 | }, |
| 193 | ) |
| --- a/video_processor/providers/cohere_provider.py | ||
| +++ b/video_processor/providers/cohere_provider.py | ||
| @@ -0,0 +1,123 @@ | ||
| 1 | +"""Cohere provider implementation.""" | |
| 2 | + | |
| 3 | +import logging | |
| 4 | +import os | |
| 5 | +from pathlib import Path | |
| 6 | +from typing import Optional | |
| 7 | + | |
| 8 | +from dotenv import load_dotenv | |
| 9 | + | |
| 10 | +from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry | |
| 11 | + | |
| 12 | +load_dotenv() | |
| 13 | +logger = logging.getLogger(__name__) | |
| 14 | + | |
| 15 | +# Curated list of Cohere models | |
| 16 | +_COHERE_MODELS = [ | |
| 17 | + ModelInfo( | |
| 18 | + id="command-r-plus", | |
| 19 | + provider="cohere", | |
| 20 | + display_name="Command R+", | |
| 21 | + capabilities=["chat"], | |
| 22 | + ), | |
| 23 | + ModelInfo( | |
| 24 | + id="command-r", | |
| 25 | + provider="cohere", | |
| 26 | + display_name="Command R", | |
| 27 | + capabilities=["chat"], | |
| 28 | + ), | |
| 29 | + ModelInfo( | |
| 30 | + id="command-light", | |
| 31 | + provider="cohere", | |
| 32 | + display_name="Command Light", | |
| 33 | + capabilities=["chat"], | |
| 34 | + ), | |
| 35 | + ModelInfo( | |
| 36 | + id="command-nightly", | |
| 37 | + provider="cohere", | |
| 38 | + display_name="Command Nightly", | |
| 39 | + capabilities=["chat"], | |
| 40 | + ), | |
| 41 | +] | |
| 42 | + | |
| 43 | + | |
| 44 | +class CohereProvider(BaseProvider): | |
| 45 | + """Cohere provider using the cohere SDK.""" | |
| 46 | + | |
| 47 | + provider_name = "cohere" | |
| 48 | + | |
| 49 | + def __init__(self, api_key: Optional[str] = None): | |
| 50 | + try: | |
| 51 | + import cohere | |
| 52 | + except ImportError: | |
| 53 | + raise ImportError("cohere package not installed. Install with: pip install cohere") | |
| 54 | + | |
| 55 | + self._api_key = api_key or os.getenv("COHERE_API_KEY") | |
| 56 | + if not self._api_key: | |
| 57 | + raise ValueError("COHERE_API_KEY not set") | |
| 58 | + | |
| 59 | + self._client = cohere.ClientV2(api_key=self._api_key) | |
| 60 | + self._last_usage = {} | |
| 61 | + | |
| 62 | + def chat( | |
| 63 | + self, | |
| 64 | + messages: list[dict], | |
| 65 | + max_tokens: int = 4096, | |
| 66 | + temperature: float = 0.7, | |
| 67 | + model: Optional[str] = None, | |
| 68 | + ) -> str: | |
| 69 | + model = model or "command-r-plus" | |
| 70 | + | |
| 71 | + response = self._client.chat( | |
| 72 | + model=model, | |
| 73 | + messages=messages, | |
| 74 | + max_tokens=max_tokens, | |
| 75 | + temperature=temperature, | |
| 76 | + ) | |
| 77 | + | |
| 78 | + usage = getattr(response, "usage", None) | |
| 79 | + tokens = getattr(usage, "tokens", None) if usage else None | |
| 80 | + self._last_usage = { | |
| 81 | + "input_tokens": getattr(tokens, "input_tokens", 0) if tokens else 0, | |
| 82 | + "output_tokens": getattr(tokens, "output_tokens", 0) if tokens else 0, | |
| 83 | + } | |
| 84 | + return response.message.content[0].text if response.message.content else "" | |
| 85 | + | |
| 86 | + def analyze_image( | |
| 87 | + self, | |
| 88 | + image_bytes: bytes, | |
| 89 | + prompt: str, | |
| 90 | + max_tokens: int = 4096, | |
| 91 | + model: Optional[str] = None, | |
| 92 | + ) -> str: | |
| 93 | + raise NotImplementedError( | |
| 94 | + "Cohere does not currently support vision/image analysis. " | |
| 95 | + "Use OpenAI, Anthropic, or Gemini for image analysis." | |
| 96 | + ) | |
| 97 | + | |
| 98 | + def transcribe_audio( | |
| 99 | + self, | |
| 100 | + audio_path: str | Path, | |
| 101 | + language: Optional[str] = None, | |
| 102 | + model: Optional[str] = None, | |
| 103 | + ) -> dict: | |
| 104 | + raise NotImplementedError( | |
| 105 | + "Cohere does not provide a transcription API. " | |
| 106 | + "Use OpenAI Whisper or Gemini for transcription." | |
| 107 | + ) | |
| 108 | + | |
| 109 | + def list_models(self) -> list[ModelInfo]: | |
| 110 | + return list(_COHERE_MODELS) | |
| 111 | + | |
| 112 | + | |
| 113 | +ProviderRegistry.register( | |
| 114 | + name="cohere", | |
| 115 | + provider_class=CohereProvider, | |
| 116 | + env_var="COHERE_API_KEY", | |
| 117 | + model_prefixes=["command-"], | |
| 118 | + default_models={ | |
| 119 | + "chat": "command-r-plus", | |
| 120 | + "vision": "", | |
| 121 | + "audio": "", | |
| 122 | + }, | |
| 123 | +) |
| --- a/video_processor/providers/cohere_provider.py | |
| +++ b/video_processor/providers/cohere_provider.py | |
| @@ -0,0 +1,123 @@ | |
| --- a/video_processor/providers/cohere_provider.py | |
| +++ b/video_processor/providers/cohere_provider.py | |
| @@ -0,0 +1,123 @@ | |
| 1 | """Cohere provider implementation.""" |
| 2 | |
| 3 | import logging |
| 4 | import os |
| 5 | from pathlib import Path |
| 6 | from typing import Optional |
| 7 | |
| 8 | from dotenv import load_dotenv |
| 9 | |
| 10 | from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry |
| 11 | |
| 12 | load_dotenv() |
| 13 | logger = logging.getLogger(__name__) |
| 14 | |
| 15 | # Curated list of Cohere models |
| 16 | _COHERE_MODELS = [ |
| 17 | ModelInfo( |
| 18 | id="command-r-plus", |
| 19 | provider="cohere", |
| 20 | display_name="Command R+", |
| 21 | capabilities=["chat"], |
| 22 | ), |
| 23 | ModelInfo( |
| 24 | id="command-r", |
| 25 | provider="cohere", |
| 26 | display_name="Command R", |
| 27 | capabilities=["chat"], |
| 28 | ), |
| 29 | ModelInfo( |
| 30 | id="command-light", |
| 31 | provider="cohere", |
| 32 | display_name="Command Light", |
| 33 | capabilities=["chat"], |
| 34 | ), |
| 35 | ModelInfo( |
| 36 | id="command-nightly", |
| 37 | provider="cohere", |
| 38 | display_name="Command Nightly", |
| 39 | capabilities=["chat"], |
| 40 | ), |
| 41 | ] |
| 42 | |
| 43 | |
class CohereProvider(BaseProvider):
    """Cohere provider using the cohere SDK.

    Wraps the Chat v2 client; vision and audio are unsupported and raise
    :class:`NotImplementedError`.
    """

    provider_name = "cohere"

    def __init__(self, api_key: Optional[str] = None):
        """Build a ClientV2 from *api_key* or the COHERE_API_KEY env var.

        Raises:
            ImportError: if the ``cohere`` package is missing.
            ValueError: if no API key can be resolved.
        """
        try:
            import cohere
        except ImportError:
            raise ImportError("cohere package not installed. Install with: pip install cohere")

        key = api_key or os.getenv("COHERE_API_KEY")
        if not key:
            raise ValueError("COHERE_API_KEY not set")
        self._api_key = key

        self._client = cohere.ClientV2(api_key=self._api_key)
        # Token counts from the most recent API call.
        self._last_usage = {}

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
        model: Optional[str] = None,
    ) -> str:
        """Run a chat completion and return the assistant's text (or "")."""
        chosen_model = model or "command-r-plus"

        response = self._client.chat(
            model=chosen_model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )

        # Cohere v2 nests counts under usage.tokens; default to 0 when absent.
        usage = getattr(response, "usage", None)
        tokens = getattr(usage, "tokens", None) if usage else None
        input_count = getattr(tokens, "input_tokens", 0) if tokens else 0
        output_count = getattr(tokens, "output_tokens", 0) if tokens else 0
        self._last_usage = {
            "input_tokens": input_count,
            "output_tokens": output_count,
        }

        content = response.message.content
        if not content:
            return ""
        return content[0].text

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
        model: Optional[str] = None,
    ) -> str:
        """Always raises: Cohere has no vision endpoint."""
        raise NotImplementedError(
            "Cohere does not currently support vision/image analysis. "
            "Use OpenAI, Anthropic, or Gemini for image analysis."
        )

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        model: Optional[str] = None,
    ) -> dict:
        """Always raises: Cohere has no transcription endpoint."""
        raise NotImplementedError(
            "Cohere does not provide a transcription API. "
            "Use OpenAI Whisper or Gemini for transcription."
        )

    def list_models(self) -> list[ModelInfo]:
        """Return a copy of the curated model list."""
        return _COHERE_MODELS.copy()
| 111 | |
| 112 | |
# Register at import time so "command-*" model ids route to Cohere.
ProviderRegistry.register(
    name="cohere",
    provider_class=CohereProvider,
    env_var="COHERE_API_KEY",
    model_prefixes=["command-"],
    default_models={
        "chat": "command-r-plus",
        "vision": "",  # not supported
        "audio": "",  # not supported
    },
)
| --- a/video_processor/providers/huggingface_provider.py | ||
| +++ b/video_processor/providers/huggingface_provider.py | ||
| @@ -0,0 +1,187 @@ | ||
| 1 | +"""Hugging Face Inference API provider implementation.""" | |
| 2 | + | |
| 3 | +import base64 | |
| 4 | +import logging | |
| 5 | +import os | |
| 6 | +from pathlib import Path | |
| 7 | +from typing import Optional | |
| 8 | + | |
| 9 | +from dotenv import load_dotenv | |
| 10 | + | |
| 11 | +from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry | |
| 12 | + | |
| 13 | +load_dotenv() | |
| 14 | +logger = logging.getLogger(__name__) | |
| 15 | + | |
| 16 | +# Curated list of popular HF Inference models | |
| 17 | +_HF_MODELS = [ | |
| 18 | + ModelInfo( | |
| 19 | + id="meta-llama/Llama-3.1-70B-Instruct", | |
| 20 | + provider="huggingface", | |
| 21 | + display_name="Llama 3.1 70B Instruct", | |
| 22 | + capabilities=["chat"], | |
| 23 | + ), | |
| 24 | + ModelInfo( | |
| 25 | + id="meta-llama/Llama-3.1-8B-Instruct", | |
| 26 | + provider="huggingface", | |
| 27 | + display_name="Llama 3.1 8B Instruct", | |
| 28 | + capabilities=["chat"], | |
| 29 | + ), | |
| 30 | + ModelInfo( | |
| 31 | + id="mistralai/Mixtral-8x7B-Instruct-v0.1", | |
| 32 | + provider="huggingface", | |
| 33 | + display_name="Mixtral 8x7B Instruct", | |
| 34 | + capabilities=["chat"], | |
| 35 | + ), | |
| 36 | + ModelInfo( | |
| 37 | + id="microsoft/Phi-3-mini-4k-instruct", | |
| 38 | + provider="huggingface", | |
| 39 | + display_name="Phi-3 Mini 4K Instruct", | |
| 40 | + capabilities=["chat"], | |
| 41 | + ), | |
| 42 | + ModelInfo( | |
| 43 | + id="llava-hf/llava-v1.6-mistral-7b-hf", | |
| 44 | + provider="huggingface", | |
| 45 | + display_name="LLaVA v1.6 Mistral 7B", | |
| 46 | + capabilities=["chat", "vision"], | |
| 47 | + ), | |
| 48 | + ModelInfo( | |
| 49 | + id="openai/whisper-large-v3", | |
| 50 | + provider="huggingface", | |
| 51 | + display_name="Whisper Large v3", | |
| 52 | + capabilities=["audio"], | |
| 53 | + ), | |
| 54 | +] | |
| 55 | + | |
| 56 | + | |
| 57 | +class HuggingFaceProvider(BaseProvider): | |
| 58 | + """Hugging Face Inference API provider using huggingface_hub.""" | |
| 59 | + | |
| 60 | + provider_name = "huggingface" | |
| 61 | + | |
| 62 | + def __init__(self, token: Optional[str] = None): | |
| 63 | + try: | |
| 64 | + from huggingface_hub import InferenceClient | |
| 65 | + except ImportError: | |
| 66 | + raise ImportError( | |
| 67 | + "huggingface_hub package not installed. Install with: pip install huggingface_hub" | |
| 68 | + ) | |
| 69 | + | |
| 70 | + self._token = token or os.getenv("HF_TOKEN") | |
| 71 | + if not self._token: | |
| 72 | + raise ValueError("HF_TOKEN not set") | |
| 73 | + | |
| 74 | + self._client = InferenceClient(token=self._token) | |
| 75 | + self._last_usage = {} | |
| 76 | + | |
| 77 | + def chat( | |
| 78 | + self, | |
| 79 | + messages: list[dict], | |
| 80 | + max_tokens: int = 4096, | |
| 81 | + temperature: float = 0.7, | |
| 82 | + model: Optional[str] = None, | |
| 83 | + ) -> str: | |
| 84 | + model = model or "meta-llama/Llama-3.1-70B-Instruct" | |
| 85 | + if model.startswith("hf/"): | |
| 86 | + model = model[len("hf/") :] | |
| 87 | + | |
| 88 | + response = self._client.chat_completion( | |
| 89 | + model=model, | |
| 90 | + messages=messages, | |
| 91 | + max_tokens=max_tokens, | |
| 92 | + temperature=temperature, | |
| 93 | + ) | |
| 94 | + | |
| 95 | + usage = getattr(response, "usage", None) | |
| 96 | + self._last_usage = { | |
| 97 | + "input_tokens": getattr(usage, "prompt_tokens", 0) if usage else 0, | |
| 98 | + "output_tokens": getattr(usage, "completion_tokens", 0) if usage else 0, | |
| 99 | + } | |
| 100 | + return response.choices[0].message.content or "" | |
| 101 | + | |
| 102 | + def analyze_image( | |
| 103 | + self, | |
| 104 | + image_bytes: bytes, | |
| 105 | + prompt: str, | |
| 106 | + max_tokens: int = 4096, | |
| 107 | + model: Optional[str] = None, | |
| 108 | + ) -> str: | |
| 109 | + model = model or "llava-hf/llava-v1.6-mistral-7b-hf" | |
| 110 | + if model.startswith("hf/"): | |
| 111 | + model = model[len("hf/") :] | |
| 112 | + | |
| 113 | + b64 = base64.b64encode(image_bytes).decode() | |
| 114 | + | |
| 115 | + response = self._client.chat_completion( | |
| 116 | + model=model, | |
| 117 | + messages=[ | |
| 118 | + { | |
| 119 | + "role": "user", | |
| 120 | + "content": [ | |
| 121 | + {"type": "text", "text": prompt}, | |
| 122 | + { | |
| 123 | + "type": "image_url", | |
| 124 | + "image_url": {"url": f"data:image/jpeg;base64,{b64}"}, | |
| 125 | + }, | |
| 126 | + ], | |
| 127 | + } | |
| 128 | + ], | |
| 129 | + max_tokens=max_tokens, | |
| 130 | + ) | |
| 131 | + | |
| 132 | + usage = getattr(response, "usage", None) | |
| 133 | + self._last_usage = { | |
| 134 | + "input_tokens": getattr(usage, "prompt_tokens", 0) if usage else 0, | |
| 135 | + "output_tokens": getattr(usage, "completion_tokens", 0) if usage else 0, | |
| 136 | + } | |
| 137 | + return response.choices[0].message.content or "" | |
| 138 | + | |
| 139 | + def transcribe_audio( | |
| 140 | + self, | |
| 141 | + audio_path: str | Path, | |
| 142 | + language: Optional[str] = None, | |
| 143 | + model: Optional[str] = None, | |
| 144 | + ) -> dict: | |
| 145 | + model = model or "openai/whisper-large-v3" | |
| 146 | + if model.startswith("hf/"): | |
| 147 | + model = model[len("hf/") :] | |
| 148 | + | |
| 149 | + audio_path = Path(audio_path) | |
| 150 | + audio_bytes = audio_path.read_bytes() | |
| 151 | + | |
| 152 | + result = self._client.automatic_speech_recognition( | |
| 153 | + audio=audio_bytes, | |
| 154 | + model=model, | |
| 155 | + ) | |
| 156 | + | |
| 157 | + text = result.text if hasattr(result, "text") else str(result) | |
| 158 | + | |
| 159 | + self._last_usage = { | |
| 160 | + "input_tokens": 0, | |
| 161 | + "output_tokens": 0, | |
| 162 | + } | |
| 163 | + | |
| 164 | + return { | |
| 165 | + "text": text, | |
| 166 | + "segments": [], | |
| 167 | + "language": language, | |
| 168 | + "duration": None, | |
| 169 | + "provider": "huggingface", | |
| 170 | + "model": model, | |
| 171 | + } | |
| 172 | + | |
| 173 | + def list_models(self) -> list[ModelInfo]: | |
| 174 | + return list(_HF_MODELS) | |
| 175 | + | |
| 176 | + | |
| 177 | +ProviderRegistry.register( | |
| 178 | + name="huggingface", | |
| 179 | + provider_class=HuggingFaceProvider, | |
| 180 | + env_var="HF_TOKEN", | |
| 181 | + model_prefixes=["hf/"], | |
| 182 | + default_models={ | |
| 183 | + "chat": "meta-llama/Llama-3.1-70B-Instruct", | |
| 184 | + "vision": "llava-hf/llava-v1.6-mistral-7b-hf", | |
| 185 | + "audio": "openai/whisper-large-v3", | |
| 186 | + }, | |
| 187 | +) |
| --- a/video_processor/providers/huggingface_provider.py | |
| +++ b/video_processor/providers/huggingface_provider.py | |
| @@ -0,0 +1,187 @@ | |
| --- a/video_processor/providers/huggingface_provider.py | |
| +++ b/video_processor/providers/huggingface_provider.py | |
| @@ -0,0 +1,187 @@ | |
| 1 | """Hugging Face Inference API provider implementation.""" |
| 2 | |
| 3 | import base64 |
| 4 | import logging |
| 5 | import os |
| 6 | from pathlib import Path |
| 7 | from typing import Optional |
| 8 | |
| 9 | from dotenv import load_dotenv |
| 10 | |
| 11 | from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry |
| 12 | |
| 13 | load_dotenv() |
| 14 | logger = logging.getLogger(__name__) |
| 15 | |
# Curated list of popular HF Inference models
_HF_MODELS = [
    ModelInfo(
        id=model_id,
        provider="huggingface",
        display_name=label,
        capabilities=caps,
    )
    for model_id, label, caps in [
        ("meta-llama/Llama-3.1-70B-Instruct", "Llama 3.1 70B Instruct", ["chat"]),
        ("meta-llama/Llama-3.1-8B-Instruct", "Llama 3.1 8B Instruct", ["chat"]),
        ("mistralai/Mixtral-8x7B-Instruct-v0.1", "Mixtral 8x7B Instruct", ["chat"]),
        ("microsoft/Phi-3-mini-4k-instruct", "Phi-3 Mini 4K Instruct", ["chat"]),
        ("llava-hf/llava-v1.6-mistral-7b-hf", "LLaVA v1.6 Mistral 7B", ["chat", "vision"]),
        ("openai/whisper-large-v3", "Whisper Large v3", ["audio"]),
    ]
]
| 55 | |
| 56 | |
class HuggingFaceProvider(BaseProvider):
    """Hugging Face Inference API provider using huggingface_hub.

    Chat and vision go through ``InferenceClient.chat_completion``; audio
    transcription uses ``automatic_speech_recognition``. Model ids may carry
    this registry's routing prefix ``hf/``, which is stripped before API calls.
    """

    provider_name = "huggingface"

    @staticmethod
    def _normalize_model(model: str) -> str:
        """Strip the registry routing prefix ("hf/") from a model id, if present."""
        return model.removeprefix("hf/")

    def _record_chat_usage(self, response) -> None:
        """Record prompt/completion token counts from a chat response (0 when absent)."""
        usage = getattr(response, "usage", None)
        self._last_usage = {
            "input_tokens": getattr(usage, "prompt_tokens", 0) if usage else 0,
            "output_tokens": getattr(usage, "completion_tokens", 0) if usage else 0,
        }

    def __init__(self, token: Optional[str] = None):
        """Build an InferenceClient from *token* or the HF_TOKEN env var.

        Raises:
            ImportError: if the ``huggingface_hub`` package is missing.
            ValueError: if no token can be resolved.
        """
        try:
            from huggingface_hub import InferenceClient
        except ImportError:
            raise ImportError(
                "huggingface_hub package not installed. Install with: pip install huggingface_hub"
            )

        # Prefer the canonical HF_TOKEN, but accept the legacy
        # HUGGINGFACEHUB_API_TOKEN name used by older HF tooling.
        self._token = (
            token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
        )
        if not self._token:
            raise ValueError("HF_TOKEN not set")

        self._client = InferenceClient(token=self._token)
        # Token counts from the most recent API call.
        self._last_usage = {}

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
        model: Optional[str] = None,
    ) -> str:
        """Run a chat completion and return the assistant's text (or "")."""
        model = self._normalize_model(model or "meta-llama/Llama-3.1-70B-Instruct")

        response = self._client.chat_completion(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )

        self._record_chat_usage(response)
        return response.choices[0].message.content or ""

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
        model: Optional[str] = None,
    ) -> str:
        """Send *prompt* plus a base64 data-URL image to a vision model; return the reply."""
        model = self._normalize_model(model or "llava-hf/llava-v1.6-mistral-7b-hf")

        b64 = base64.b64encode(image_bytes).decode()

        response = self._client.chat_completion(
            model=model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                        },
                    ],
                }
            ],
            max_tokens=max_tokens,
        )

        self._record_chat_usage(response)
        return response.choices[0].message.content or ""

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        model: Optional[str] = None,
    ) -> dict:
        """Transcribe the audio file at *audio_path*.

        Returns a dict with keys text/segments/language/duration/provider/model.
        NOTE(review): the ASR endpoint does not take *language*; it is echoed
        back untranslated in the result.
        """
        model = self._normalize_model(model or "openai/whisper-large-v3")

        audio_bytes = Path(audio_path).read_bytes()

        result = self._client.automatic_speech_recognition(
            audio=audio_bytes,
            model=model,
        )

        # Newer hub versions return an object with .text; fall back to str().
        text = result.text if hasattr(result, "text") else str(result)

        # ASR responses carry no token usage.
        self._last_usage = {
            "input_tokens": 0,
            "output_tokens": 0,
        }

        return {
            "text": text,
            "segments": [],
            "language": language,
            "duration": None,
            "provider": "huggingface",
            "model": model,
        }

    def list_models(self) -> list[ModelInfo]:
        """Return a copy of the curated model list."""
        return list(_HF_MODELS)
| 175 | |
| 176 | |
# Register at import time so "hf/"-prefixed model ids route to Hugging Face.
ProviderRegistry.register(
    name="huggingface",
    provider_class=HuggingFaceProvider,
    env_var="HF_TOKEN",
    model_prefixes=["hf/"],
    default_models={
        "chat": "meta-llama/Llama-3.1-70B-Instruct",
        "vision": "llava-hf/llava-v1.6-mistral-7b-hf",
        "audio": "openai/whisper-large-v3",
    },
)
| --- a/video_processor/providers/litellm_provider.py | ||
| +++ b/video_processor/providers/litellm_provider.py | ||
| @@ -0,0 +1,171 @@ | ||
| 1 | +"""LiteLLM universal proxy provider implementation.""" | |
| 2 | + | |
| 3 | +import base64 | |
| 4 | +import logging | |
| 5 | +from pathlib import Path | |
| 6 | +from typing import Optional | |
| 7 | + | |
| 8 | +from dotenv import load_dotenv | |
| 9 | + | |
| 10 | +from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry | |
| 11 | + | |
| 12 | +load_dotenv() | |
| 13 | +logger = logging.getLogger(__name__) | |
| 14 | + | |
| 15 | + | |
| 16 | +class LiteLLMProvider(BaseProvider): | |
| 17 | + """LiteLLM universal proxy provider. | |
| 18 | + | |
| 19 | + LiteLLM supports 100+ LLM providers through a unified interface. | |
| 20 | + It reads provider API keys from environment variables automatically | |
| 21 | + (e.g. OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.). | |
| 22 | + """ | |
| 23 | + | |
| 24 | + provider_name = "litellm" | |
| 25 | + | |
| 26 | + def __init__(self): | |
| 27 | + try: | |
| 28 | + import litellm # noqa: F401 | |
| 29 | + except ImportError: | |
| 30 | + raise ImportError("litellm package not installed. Install with: pip install litellm") | |
| 31 | + | |
| 32 | + self._litellm = litellm | |
| 33 | + self._last_usage = {} | |
| 34 | + | |
| 35 | + def chat( | |
| 36 | + self, | |
| 37 | + messages: list[dict], | |
| 38 | + max_tokens: int = 4096, | |
| 39 | + temperature: float = 0.7, | |
| 40 | + model: Optional[str] = None, | |
| 41 | + ) -> str: | |
| 42 | + if not model: | |
| 43 | + raise ValueError( | |
| 44 | + "LiteLLM requires an explicit model in provider/model format " | |
| 45 | + "(e.g. 'openai/gpt-4o', 'anthropic/claude-3-sonnet-20240229')" | |
| 46 | + ) | |
| 47 | + | |
| 48 | + response = self._litellm.completion( | |
| 49 | + model=model, | |
| 50 | + messages=messages, | |
| 51 | + max_tokens=max_tokens, | |
| 52 | + temperature=temperature, | |
| 53 | + ) | |
| 54 | + | |
| 55 | + usage = getattr(response, "usage", None) | |
| 56 | + self._last_usage = { | |
| 57 | + "input_tokens": getattr(usage, "prompt_tokens", 0) if usage else 0, | |
| 58 | + "output_tokens": getattr(usage, "completion_tokens", 0) if usage else 0, | |
| 59 | + } | |
| 60 | + return response.choices[0].message.content or "" | |
| 61 | + | |
| 62 | + def analyze_image( | |
| 63 | + self, | |
| 64 | + image_bytes: bytes, | |
| 65 | + prompt: str, | |
| 66 | + max_tokens: int = 4096, | |
| 67 | + model: Optional[str] = None, | |
| 68 | + ) -> str: | |
| 69 | + if not model: | |
| 70 | + raise ValueError( | |
| 71 | + "LiteLLM requires an explicit model for image analysis " | |
| 72 | + "(e.g. 'openai/gpt-4o', 'anthropic/claude-3-sonnet-20240229')" | |
| 73 | + ) | |
| 74 | + | |
| 75 | + b64 = base64.b64encode(image_bytes).decode() | |
| 76 | + | |
| 77 | + response = self._litellm.completion( | |
| 78 | + model=model, | |
| 79 | + messages=[ | |
| 80 | + { | |
| 81 | + "role": "user", | |
| 82 | + "content": [ | |
| 83 | + {"type": "text", "text": prompt}, | |
| 84 | + { | |
| 85 | + "type": "image_url", | |
| 86 | + "image_url": {"url": f"data:image/jpeg;base64,{b64}"}, | |
| 87 | + }, | |
| 88 | + ], | |
| 89 | + } | |
| 90 | + ], | |
| 91 | + max_tokens=max_tokens, | |
| 92 | + ) | |
| 93 | + | |
| 94 | + usage = getattr(response, "usage", None) | |
| 95 | + self._last_usage = { | |
| 96 | + "input_tokens": getattr(usage, "prompt_tokens", 0) if usage else 0, | |
| 97 | + "output_tokens": getattr(usage, "completion_tokens", 0) if usage else 0, | |
| 98 | + } | |
| 99 | + return response.choices[0].message.content or "" | |
| 100 | + | |
| 101 | + def transcribe_audio( | |
| 102 | + self, | |
| 103 | + audio_path: str | Path, | |
| 104 | + language: Optional[str] = None, | |
| 105 | + model: Optional[str] = None, | |
| 106 | + ) -> dict: | |
| 107 | + model = model or "whisper-1" | |
| 108 | + | |
| 109 | + try: | |
| 110 | + with open(audio_path, "rb") as f: | |
| 111 | + response = self._litellm.transcription( | |
| 112 | + model=model, | |
| 113 | + file=f, | |
| 114 | + language=language, | |
| 115 | + ) | |
| 116 | + | |
| 117 | + text = getattr(response, "text", str(response)) | |
| 118 | + self._last_usage = { | |
| 119 | + "input_tokens": 0, | |
| 120 | + "output_tokens": 0, | |
| 121 | + } | |
| 122 | + | |
| 123 | + return { | |
| 124 | + "text": text, | |
| 125 | + "segments": [], | |
| 126 | + "language": language, | |
| 127 | + "duration": None, | |
| 128 | + "provider": "litellm", | |
| 129 | + "model": model, | |
| 130 | + } | |
| 131 | + except Exception: | |
| 132 | + raise NotImplementedError( | |
| 133 | + "Audio transcription failed via LiteLLM. " | |
| 134 | + "Ensure the underlying provider supports transcription." | |
| 135 | + ) | |
| 136 | + | |
| 137 | + def list_models(self) -> list[ModelInfo]: | |
| 138 | + try: | |
| 139 | + model_list = getattr(self._litellm, "model_list", None) | |
| 140 | + if model_list: | |
| 141 | + return [ | |
| 142 | + ModelInfo( | |
| 143 | + id=m if isinstance(m, str) else str(m), | |
| 144 | + provider="litellm", | |
| 145 | + display_name=m if isinstance(m, str) else str(m), | |
| 146 | + capabilities=["chat"], | |
| 147 | + ) | |
| 148 | + for m in model_list | |
| 149 | + ] | |
| 150 | + except Exception as e: | |
| 151 | + logger.warning(f"Failed to list LiteLLM models: {e}") | |
| 152 | + return [] | |
| 153 | + | |
| 154 | + | |
| 155 | +# Only register if litellm is importable | |
| 156 | +try: | |
| 157 | + import litellm # noqa: F401 | |
| 158 | + | |
| 159 | + ProviderRegistry.register( | |
| 160 | + name="litellm", | |
| 161 | + provider_class=LiteLLMProvider, | |
| 162 | + env_var="", | |
| 163 | + model_prefixes=[], | |
| 164 | + default_models={ | |
| 165 | + "chat": "", | |
| 166 | + "vision": "", | |
| 167 | + "audio": "", | |
| 168 | + }, | |
| 169 | + ) | |
| 170 | +except ImportError: | |
| 171 | + pass |
| --- a/video_processor/providers/litellm_provider.py | |
| +++ b/video_processor/providers/litellm_provider.py | |
| @@ -0,0 +1,171 @@ | |
| --- a/video_processor/providers/litellm_provider.py | |
| +++ b/video_processor/providers/litellm_provider.py | |
| @@ -0,0 +1,171 @@ | |
| 1 | """LiteLLM universal proxy provider implementation.""" |
| 2 | |
| 3 | import base64 |
| 4 | import logging |
| 5 | from pathlib import Path |
| 6 | from typing import Optional |
| 7 | |
| 8 | from dotenv import load_dotenv |
| 9 | |
| 10 | from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry |
| 11 | |
| 12 | load_dotenv() |
| 13 | logger = logging.getLogger(__name__) |
| 14 | |
| 15 | |
class LiteLLMProvider(BaseProvider):
    """LiteLLM universal proxy provider.

    LiteLLM supports 100+ LLM providers through a unified interface.
    It reads provider API keys from environment variables automatically
    (e.g. OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.). Because routing is
    per-model, chat and vision calls require an explicit provider/model id.
    """

    provider_name = "litellm"

    def __init__(self):
        """Import litellm lazily and keep a module handle.

        Raises:
            ImportError: if the ``litellm`` package is missing.
        """
        try:
            import litellm
        except ImportError:
            raise ImportError("litellm package not installed. Install with: pip install litellm")

        self._litellm = litellm
        # Token counts from the most recent API call.
        self._last_usage = {}

    def _record_usage(self, response) -> None:
        """Record prompt/completion token counts from a completion response (0 when absent)."""
        usage = getattr(response, "usage", None)
        self._last_usage = {
            "input_tokens": getattr(usage, "prompt_tokens", 0) if usage else 0,
            "output_tokens": getattr(usage, "completion_tokens", 0) if usage else 0,
        }

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
        model: Optional[str] = None,
    ) -> str:
        """Run a chat completion and return the assistant's text (or "").

        Raises:
            ValueError: if *model* is not given (LiteLLM cannot pick a default).
        """
        if not model:
            raise ValueError(
                "LiteLLM requires an explicit model in provider/model format "
                "(e.g. 'openai/gpt-4o', 'anthropic/claude-3-sonnet-20240229')"
            )

        response = self._litellm.completion(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )

        self._record_usage(response)
        return response.choices[0].message.content or ""

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
        model: Optional[str] = None,
    ) -> str:
        """Send *prompt* plus a base64 data-URL image to a vision model; return the reply.

        Raises:
            ValueError: if *model* is not given.
        """
        if not model:
            raise ValueError(
                "LiteLLM requires an explicit model for image analysis "
                "(e.g. 'openai/gpt-4o', 'anthropic/claude-3-sonnet-20240229')"
            )

        b64 = base64.b64encode(image_bytes).decode()

        response = self._litellm.completion(
            model=model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                        },
                    ],
                }
            ],
            max_tokens=max_tokens,
        )

        self._record_usage(response)
        return response.choices[0].message.content or ""

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        model: Optional[str] = None,
    ) -> dict:
        """Transcribe the audio file at *audio_path* via LiteLLM's transcription route.

        Returns a dict with keys text/segments/language/duration/provider/model.

        Raises:
            NotImplementedError: if the underlying provider call fails for any
                reason (chained to the original exception for debuggability).
        """
        model = model or "whisper-1"

        try:
            with open(audio_path, "rb") as f:
                response = self._litellm.transcription(
                    model=model,
                    file=f,
                    language=language,
                )

            text = getattr(response, "text", str(response))
            # Transcription responses carry no token usage.
            self._last_usage = {
                "input_tokens": 0,
                "output_tokens": 0,
            }

            return {
                "text": text,
                "segments": [],
                "language": language,
                "duration": None,
                "provider": "litellm",
                "model": model,
            }
        except Exception as exc:
            # Chain the original error so the real cause is not lost.
            raise NotImplementedError(
                "Audio transcription failed via LiteLLM. "
                "Ensure the underlying provider supports transcription."
            ) from exc

    def list_models(self) -> list[ModelInfo]:
        """Best-effort listing from litellm.model_list; empty when unavailable."""
        try:
            model_list = getattr(self._litellm, "model_list", None)
            if model_list:
                return [
                    ModelInfo(
                        id=m if isinstance(m, str) else str(m),
                        provider="litellm",
                        display_name=m if isinstance(m, str) else str(m),
                        capabilities=["chat"],
                    )
                    for m in model_list
                ]
        except Exception as e:
            logger.warning(f"Failed to list LiteLLM models: {e}")
        return []
| 153 | |
| 154 | |
# Only register if litellm is importable: this provider has no dedicated API
# key (keys belong to the proxied providers), so registration is gated on the
# optional dependency instead of an env var.
try:
    import litellm  # noqa: F401

    ProviderRegistry.register(
        name="litellm",
        provider_class=LiteLLMProvider,
        env_var="",
        model_prefixes=[],
        default_models={
            "chat": "",
            "vision": "",
            "audio": "",
        },
    )
except ImportError:
    pass
| --- a/video_processor/providers/mistral_provider.py | ||
| +++ b/video_processor/providers/mistral_provider.py | ||
| @@ -0,0 +1,167 @@ | ||
| 1 | +"""Mistral AI provider implementation.""" | |
| 2 | + | |
| 3 | +import base64 | |
| 4 | +import logging | |
| 5 | +import os | |
| 6 | +from pathlib import Path | |
| 7 | +from typing import Optional | |
| 8 | + | |
| 9 | +from dotenv import load_dotenv | |
| 10 | + | |
| 11 | +from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry | |
| 12 | + | |
| 13 | +load_dotenv() | |
| 14 | +logger = logging.getLogger(__name__) | |
| 15 | + | |
| 16 | +# Curated list of Mistral models | |
| 17 | +_MISTRAL_MODELS = [ | |
| 18 | + ModelInfo( | |
| 19 | + id="mistral-large-latest", | |
| 20 | + provider="mistral", | |
| 21 | + display_name="Mistral Large", | |
| 22 | + capabilities=["chat"], | |
| 23 | + ), | |
| 24 | + ModelInfo( | |
| 25 | + id="mistral-medium-latest", | |
| 26 | + provider="mistral", | |
| 27 | + display_name="Mistral Medium", | |
| 28 | + capabilities=["chat"], | |
| 29 | + ), | |
| 30 | + ModelInfo( | |
| 31 | + id="mistral-small-latest", | |
| 32 | + provider="mistral", | |
| 33 | + display_name="Mistral Small", | |
| 34 | + capabilities=["chat"], | |
| 35 | + ), | |
| 36 | + ModelInfo( | |
| 37 | + id="open-mistral-nemo", | |
| 38 | + provider="mistral", | |
| 39 | + display_name="Mistral Nemo", | |
| 40 | + capabilities=["chat"], | |
| 41 | + ), | |
| 42 | + ModelInfo( | |
| 43 | + id="pixtral-large-latest", | |
| 44 | + provider="mistral", | |
| 45 | + display_name="Pixtral Large", | |
| 46 | + capabilities=["chat", "vision"], | |
| 47 | + ), | |
| 48 | + ModelInfo( | |
| 49 | + id="pixtral-12b-2409", | |
| 50 | + provider="mistral", | |
| 51 | + display_name="Pixtral 12B", | |
| 52 | + capabilities=["chat", "vision"], | |
| 53 | + ), | |
| 54 | + ModelInfo( | |
| 55 | + id="codestral-latest", | |
| 56 | + provider="mistral", | |
| 57 | + display_name="Codestral", | |
| 58 | + capabilities=["chat"], | |
| 59 | + ), | |
| 60 | +] | |
| 61 | + | |
| 62 | + | |
| 63 | +class MistralProvider(BaseProvider): | |
| 64 | + """Mistral AI provider using the mistralai SDK.""" | |
| 65 | + | |
| 66 | + provider_name = "mistral" | |
| 67 | + | |
| 68 | + def __init__(self, api_key: Optional[str] = None): | |
| 69 | + try: | |
| 70 | + from mistralai import Mistral | |
| 71 | + except ImportError: | |
| 72 | + raise ImportError( | |
| 73 | + "mistralai package not installed. Install with: pip install mistralai" | |
| 74 | + ) | |
| 75 | + | |
| 76 | + self._api_key = api_key or os.getenv("MISTRAL_API_KEY") | |
| 77 | + if not self._api_key: | |
| 78 | + raise ValueError("MISTRAL_API_KEY not set") | |
| 79 | + | |
| 80 | + self._client = Mistral(api_key=self._api_key) | |
| 81 | + self._last_usage = {} | |
| 82 | + | |
| 83 | + def chat( | |
| 84 | + self, | |
| 85 | + messages: list[dict], | |
| 86 | + max_tokens: int = 4096, | |
| 87 | + temperature: float = 0.7, | |
| 88 | + model: Optional[str] = None, | |
| 89 | + ) -> str: | |
| 90 | + model = model or "mistral-large-latest" | |
| 91 | + | |
| 92 | + response = self._client.chat.complete( | |
| 93 | + model=model, | |
| 94 | + messages=messages, | |
| 95 | + max_tokens=max_tokens, | |
| 96 | + temperature=temperature, | |
| 97 | + ) | |
| 98 | + | |
| 99 | + self._last_usage = { | |
| 100 | + "input_tokens": getattr(response.usage, "prompt_tokens", 0) if response.usage else 0, | |
| 101 | + "output_tokens": getattr(response.usage, "completion_tokens", 0) | |
| 102 | + if response.usage | |
| 103 | + else 0, | |
| 104 | + } | |
| 105 | + return response.choices[0].message.content or "" | |
| 106 | + | |
| 107 | + def analyze_image( | |
| 108 | + self, | |
| 109 | + image_bytes: bytes, | |
| 110 | + prompt: str, | |
| 111 | + max_tokens: int = 4096, | |
| 112 | + model: Optional[str] = None, | |
| 113 | + ) -> str: | |
| 114 | + model = model or "pixtral-large-latest" | |
| 115 | + b64 = base64.b64encode(image_bytes).decode() | |
| 116 | + | |
| 117 | + response = self._client.chat.complete( | |
| 118 | + model=model, | |
| 119 | + messages=[ | |
| 120 | + { | |
| 121 | + "role": "user", | |
| 122 | + "content": [ | |
| 123 | + {"type": "text", "text": prompt}, | |
| 124 | + { | |
| 125 | + "type": "image_url", | |
| 126 | + "image_url": {"url": f"data:image/jpeg;base64,{b64}"}, | |
| 127 | + }, | |
| 128 | + ], | |
| 129 | + } | |
| 130 | + ], | |
| 131 | + max_tokens=max_tokens, | |
| 132 | + ) | |
| 133 | + | |
| 134 | + self._last_usage = { | |
| 135 | + "input_tokens": getattr(response.usage, "prompt_tokens", 0) if response.usage else 0, | |
| 136 | + "output_tokens": getattr(response.usage, "completion_tokens", 0) | |
| 137 | + if response.usage | |
| 138 | + else 0, | |
| 139 | + } | |
| 140 | + return response.choices[0].message.content or "" | |
| 141 | + | |
| 142 | + def transcribe_audio( | |
| 143 | + self, | |
| 144 | + audio_path: str | Path, | |
| 145 | + language: Optional[str] = None, | |
| 146 | + model: Optional[str] = None, | |
| 147 | + ) -> dict: | |
| 148 | + raise NotImplementedError( | |
| 149 | + "Mistral does not provide a transcription API. " | |
| 150 | + "Use OpenAI Whisper or Gemini for transcription." | |
| 151 | + ) | |
| 152 | + | |
| 153 | + def list_models(self) -> list[ModelInfo]: | |
| 154 | + return list(_MISTRAL_MODELS) | |
| 155 | + | |
| 156 | + | |
| 157 | +ProviderRegistry.register( | |
| 158 | + name="mistral", | |
| 159 | + provider_class=MistralProvider, | |
| 160 | + env_var="MISTRAL_API_KEY", | |
| 161 | + model_prefixes=["mistral-", "pixtral-", "codestral-", "open-mistral-"], | |
| 162 | + default_models={ | |
| 163 | + "chat": "mistral-large-latest", | |
| 164 | + "vision": "pixtral-large-latest", | |
| 165 | + "audio": "", | |
| 166 | + }, | |
| 167 | +) |
| --- a/video_processor/providers/mistral_provider.py | |
| +++ b/video_processor/providers/mistral_provider.py | |
| @@ -0,0 +1,167 @@ | |
| --- a/video_processor/providers/mistral_provider.py | |
| +++ b/video_processor/providers/mistral_provider.py | |
| @@ -0,0 +1,167 @@ | |
| 1 | """Mistral AI provider implementation.""" |
| 2 | |
| 3 | import base64 |
| 4 | import logging |
| 5 | import os |
| 6 | from pathlib import Path |
| 7 | from typing import Optional |
| 8 | |
| 9 | from dotenv import load_dotenv |
| 10 | |
| 11 | from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry |
| 12 | |
| 13 | load_dotenv() |
| 14 | logger = logging.getLogger(__name__) |
| 15 | |
# Curated list of Mistral models, built from (id, display name, capabilities)
# specs so additions stay one line each.
_MISTRAL_MODELS = [
    ModelInfo(id=model_id, provider="mistral", display_name=label, capabilities=caps)
    for model_id, label, caps in (
        ("mistral-large-latest", "Mistral Large", ["chat"]),
        ("mistral-medium-latest", "Mistral Medium", ["chat"]),
        ("mistral-small-latest", "Mistral Small", ["chat"]),
        ("open-mistral-nemo", "Mistral Nemo", ["chat"]),
        ("pixtral-large-latest", "Pixtral Large", ["chat", "vision"]),
        ("pixtral-12b-2409", "Pixtral 12B", ["chat", "vision"]),
        ("codestral-latest", "Codestral", ["chat"]),
    )
]
| 61 | |
| 62 | |
class MistralProvider(BaseProvider):
    """Mistral AI provider using the mistralai SDK.

    Supports chat and, via the Pixtral models, image analysis. Mistral has
    no speech-to-text endpoint, so ``transcribe_audio`` always raises.
    """

    provider_name = "mistral"

    def __init__(self, api_key: Optional[str] = None):
        """Build a Mistral client.

        Args:
            api_key: Explicit API key; falls back to the MISTRAL_API_KEY
                environment variable.

        Raises:
            ImportError: If the ``mistralai`` package is not installed.
            ValueError: If no API key can be resolved.
        """
        try:
            from mistralai import Mistral
        except ImportError:
            raise ImportError(
                "mistralai package not installed. Install with: pip install mistralai"
            )

        self._api_key = api_key or os.getenv("MISTRAL_API_KEY")
        if not self._api_key:
            raise ValueError("MISTRAL_API_KEY not set")

        self._client = Mistral(api_key=self._api_key)
        # Token counts from the most recent API call.
        self._last_usage: dict = {}

    @staticmethod
    def _usage_from(response) -> dict:
        """Normalize a completion response's usage into token counts.

        Shared by chat() and analyze_image() so the usage bookkeeping cannot
        drift between the two call paths (previously copy-pasted in both).
        """
        usage = getattr(response, "usage", None)
        return {
            "input_tokens": getattr(usage, "prompt_tokens", 0) if usage else 0,
            "output_tokens": getattr(usage, "completion_tokens", 0) if usage else 0,
        }

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
        model: Optional[str] = None,
    ) -> str:
        """Run a chat completion and return the assistant message text.

        Args:
            messages: OpenAI-style role/content message dicts.
            max_tokens: Completion token cap.
            temperature: Sampling temperature.
            model: Model id; defaults to "mistral-large-latest".
        """
        model = model or "mistral-large-latest"

        response = self._client.chat.complete(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )

        self._last_usage = self._usage_from(response)
        return response.choices[0].message.content or ""

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
        model: Optional[str] = None,
    ) -> str:
        """Describe an image with a Pixtral vision model.

        The image is sent inline as a base64 data URL (assumed JPEG —
        TODO confirm callers only pass JPEG frames).
        """
        model = model or "pixtral-large-latest"
        b64 = base64.b64encode(image_bytes).decode()

        response = self._client.chat.complete(
            model=model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                        },
                    ],
                }
            ],
            max_tokens=max_tokens,
        )

        self._last_usage = self._usage_from(response)
        return response.choices[0].message.content or ""

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        model: Optional[str] = None,
    ) -> dict:
        """Not supported: Mistral has no speech-to-text endpoint.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Mistral does not provide a transcription API. "
            "Use OpenAI Whisper or Gemini for transcription."
        )

    def list_models(self) -> list[ModelInfo]:
        """Return a copy of the curated Mistral model catalog."""
        return list(_MISTRAL_MODELS)
| 155 | |
| 156 | |
# Route "mistral-", "pixtral-", "codestral-" and "open-mistral-" model ids to
# this provider. The "audio" default is empty because Mistral exposes no
# transcription API (transcribe_audio raises NotImplementedError).
ProviderRegistry.register(
    name="mistral",
    provider_class=MistralProvider,
    env_var="MISTRAL_API_KEY",
    model_prefixes=["mistral-", "pixtral-", "codestral-", "open-mistral-"],
    default_models={
        "chat": "mistral-large-latest",
        "vision": "pixtral-large-latest",
        "audio": "",
    },
)
| --- a/video_processor/providers/qianfan_provider.py | ||
| +++ b/video_processor/providers/qianfan_provider.py | ||
| @@ -0,0 +1,138 @@ | ||
| 1 | +"""Baidu Qianfan (ERNIE) provider implementation.""" | |
| 2 | + | |
| 3 | +import logging | |
| 4 | +import os | |
| 5 | +from pathlib import Path | |
| 6 | +from typing import Optional | |
| 7 | + | |
| 8 | +from dotenv import load_dotenv | |
| 9 | + | |
| 10 | +from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry | |
| 11 | + | |
| 12 | +load_dotenv() | |
| 13 | +logger = logging.getLogger(__name__) | |
| 14 | + | |
# Curated list of Qianfan models; all ERNIE variants here are chat-only.
_QIANFAN_MODELS = [
    ModelInfo(id=model_id, provider="qianfan", display_name=label, capabilities=["chat"])
    for model_id, label in (
        ("ernie-4.0-8k", "ERNIE 4.0 8K"),
        ("ernie-3.5-8k", "ERNIE 3.5 8K"),
        ("ernie-speed-8k", "ERNIE Speed 8K"),
        ("ernie-lite-8k", "ERNIE Lite 8K"),
    )
]
| 42 | + | |
| 43 | + | |
class QianfanProvider(BaseProvider):
    """Baidu Qianfan provider using the qianfan SDK.

    Chat-only: image analysis and transcription are intentionally
    unimplemented and raise NotImplementedError.
    """

    provider_name = "qianfan"

    def __init__(
        self,
        access_key: Optional[str] = None,
        secret_key: Optional[str] = None,
    ):
        """Resolve the Qianfan AK/SK pair and prepare the SDK module.

        Raises:
            ImportError: If the ``qianfan`` package is not installed.
            ValueError: If either credential is missing.
        """
        try:
            import qianfan
        except ImportError:
            raise ImportError("qianfan package not installed. Install with: pip install qianfan")

        self._access_key = access_key or os.getenv("QIANFAN_ACCESS_KEY")
        self._secret_key = secret_key or os.getenv("QIANFAN_SECRET_KEY")

        if not self._access_key or not self._secret_key:
            raise ValueError("QIANFAN_ACCESS_KEY and QIANFAN_SECRET_KEY must both be set")

        # The SDK reads credentials from the environment, so export them.
        os.environ["QIANFAN_ACCESS_KEY"] = self._access_key
        os.environ["QIANFAN_SECRET_KEY"] = self._secret_key

        self._qianfan = qianfan
        # Token counts from the most recent API call.
        self._last_usage = {}

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
        model: Optional[str] = None,
    ) -> str:
        """Run an ERNIE chat completion and return the result text."""
        chosen = model or "ernie-4.0-8k"
        prefix = "qianfan/"
        if chosen.startswith(prefix):
            chosen = chosen[len(prefix):]

        raw = self._qianfan.ChatCompletion().do(
            model=chosen,
            messages=messages,
            temperature=temperature,
            max_output_tokens=max_tokens,
        )

        # SDK responses may be dict-like with the payload under "body";
        # fall back to the raw object otherwise.
        payload = raw.get("body", raw) if hasattr(raw, "get") else raw
        token_usage = payload.get("usage", {}) if hasattr(payload, "get") else {}
        self._last_usage = {
            "input_tokens": token_usage.get("prompt_tokens", 0),
            "output_tokens": token_usage.get("completion_tokens", 0),
        }

        if hasattr(payload, "get"):
            return payload.get("result", "")
        return str(payload)

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
        model: Optional[str] = None,
    ) -> str:
        """Unsupported for Qianfan; always raises NotImplementedError."""
        raise NotImplementedError(
            "Qianfan image analysis is not supported in this provider. "
            "Use OpenAI, Anthropic, or Gemini for image analysis."
        )

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        model: Optional[str] = None,
    ) -> dict:
        """Unsupported for Qianfan; always raises NotImplementedError."""
        raise NotImplementedError(
            "Qianfan does not provide a transcription API through this provider. "
            "Use OpenAI Whisper or Gemini for transcription."
        )

    def list_models(self) -> list[ModelInfo]:
        """Return a copy of the curated ERNIE model catalog."""
        return list(_QIANFAN_MODELS)
| 126 | + | |
| 127 | + | |
# Register ERNIE model ids ("ernie-" / "qianfan/" prefixes) with the registry.
# NOTE(review): env_var lists only QIANFAN_ACCESS_KEY, but the provider also
# requires QIANFAN_SECRET_KEY — confirm the registry uses env_var purely as an
# availability hint. Vision/audio defaults are empty: unsupported here.
ProviderRegistry.register(
    name="qianfan",
    provider_class=QianfanProvider,
    env_var="QIANFAN_ACCESS_KEY",
    model_prefixes=["ernie-", "qianfan/"],
    default_models={
        "chat": "ernie-4.0-8k",
        "vision": "",
        "audio": "",
    },
)
| --- a/video_processor/providers/qianfan_provider.py | |
| +++ b/video_processor/providers/qianfan_provider.py | |
| @@ -0,0 +1,138 @@ | |
| --- a/video_processor/providers/qianfan_provider.py | |
| +++ b/video_processor/providers/qianfan_provider.py | |
| @@ -0,0 +1,138 @@ | |
| 1 | """Baidu Qianfan (ERNIE) provider implementation.""" |
| 2 | |
| 3 | import logging |
| 4 | import os |
| 5 | from pathlib import Path |
| 6 | from typing import Optional |
| 7 | |
| 8 | from dotenv import load_dotenv |
| 9 | |
| 10 | from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry |
| 11 | |
| 12 | load_dotenv() |
| 13 | logger = logging.getLogger(__name__) |
| 14 | |
# Curated list of Qianfan models; all ERNIE variants here are chat-only.
_QIANFAN_MODELS = [
    ModelInfo(id=model_id, provider="qianfan", display_name=label, capabilities=["chat"])
    for model_id, label in (
        ("ernie-4.0-8k", "ERNIE 4.0 8K"),
        ("ernie-3.5-8k", "ERNIE 3.5 8K"),
        ("ernie-speed-8k", "ERNIE Speed 8K"),
        ("ernie-lite-8k", "ERNIE Lite 8K"),
    )
]
| 42 | |
| 43 | |
class QianfanProvider(BaseProvider):
    """Baidu Qianfan provider using the qianfan SDK.

    Chat-only: image analysis and transcription are intentionally
    unimplemented and raise NotImplementedError.
    """

    provider_name = "qianfan"

    def __init__(
        self,
        access_key: Optional[str] = None,
        secret_key: Optional[str] = None,
    ):
        """Resolve the Qianfan AK/SK pair and prepare the SDK module.

        Raises:
            ImportError: If the ``qianfan`` package is not installed.
            ValueError: If either credential is missing.
        """
        try:
            import qianfan
        except ImportError:
            raise ImportError("qianfan package not installed. Install with: pip install qianfan")

        self._access_key = access_key or os.getenv("QIANFAN_ACCESS_KEY")
        self._secret_key = secret_key or os.getenv("QIANFAN_SECRET_KEY")

        if not self._access_key or not self._secret_key:
            raise ValueError("QIANFAN_ACCESS_KEY and QIANFAN_SECRET_KEY must both be set")

        # The SDK reads credentials from the environment, so export them.
        os.environ["QIANFAN_ACCESS_KEY"] = self._access_key
        os.environ["QIANFAN_SECRET_KEY"] = self._secret_key

        self._qianfan = qianfan
        # Token counts from the most recent API call.
        self._last_usage = {}

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
        model: Optional[str] = None,
    ) -> str:
        """Run an ERNIE chat completion and return the result text."""
        chosen = model or "ernie-4.0-8k"
        prefix = "qianfan/"
        if chosen.startswith(prefix):
            chosen = chosen[len(prefix):]

        raw = self._qianfan.ChatCompletion().do(
            model=chosen,
            messages=messages,
            temperature=temperature,
            max_output_tokens=max_tokens,
        )

        # SDK responses may be dict-like with the payload under "body";
        # fall back to the raw object otherwise.
        payload = raw.get("body", raw) if hasattr(raw, "get") else raw
        token_usage = payload.get("usage", {}) if hasattr(payload, "get") else {}
        self._last_usage = {
            "input_tokens": token_usage.get("prompt_tokens", 0),
            "output_tokens": token_usage.get("completion_tokens", 0),
        }

        if hasattr(payload, "get"):
            return payload.get("result", "")
        return str(payload)

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
        model: Optional[str] = None,
    ) -> str:
        """Unsupported for Qianfan; always raises NotImplementedError."""
        raise NotImplementedError(
            "Qianfan image analysis is not supported in this provider. "
            "Use OpenAI, Anthropic, or Gemini for image analysis."
        )

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        model: Optional[str] = None,
    ) -> dict:
        """Unsupported for Qianfan; always raises NotImplementedError."""
        raise NotImplementedError(
            "Qianfan does not provide a transcription API through this provider. "
            "Use OpenAI Whisper or Gemini for transcription."
        )

    def list_models(self) -> list[ModelInfo]:
        """Return a copy of the curated ERNIE model catalog."""
        return list(_QIANFAN_MODELS)
| 126 | |
| 127 | |
# Register ERNIE model ids ("ernie-" / "qianfan/" prefixes) with the registry.
# NOTE(review): env_var lists only QIANFAN_ACCESS_KEY, but the provider also
# requires QIANFAN_SECRET_KEY — confirm the registry uses env_var purely as an
# availability hint. Vision/audio defaults are empty: unsupported here.
ProviderRegistry.register(
    name="qianfan",
    provider_class=QianfanProvider,
    env_var="QIANFAN_ACCESS_KEY",
    model_prefixes=["ernie-", "qianfan/"],
    default_models={
        "chat": "ernie-4.0-8k",
        "vision": "",
        "audio": "",
    },
)
| --- a/video_processor/providers/vertex_provider.py | ||
| +++ b/video_processor/providers/vertex_provider.py | ||
| @@ -0,0 +1,226 @@ | ||
| 1 | +"""Google Vertex AI provider implementation.""" | |
| 2 | + | |
| 3 | +import logging | |
| 4 | +import os | |
| 5 | +from pathlib import Path | |
| 6 | +from typing import Optional | |
| 7 | + | |
| 8 | +from dotenv import load_dotenv | |
| 9 | + | |
| 10 | +from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry | |
| 11 | + | |
| 12 | +load_dotenv() | |
| 13 | +logger = logging.getLogger(__name__) | |
| 14 | + | |
# Curated list of models available on Vertex AI; every Gemini variant here
# handles chat, vision and audio.
_VERTEX_MODELS = [
    ModelInfo(
        id=model_id,
        provider="vertex",
        display_name=label,
        capabilities=["chat", "vision", "audio"],
    )
    for model_id, label in (
        ("gemini-2.0-flash", "Gemini 2.0 Flash"),
        ("gemini-2.0-pro", "Gemini 2.0 Pro"),
        ("gemini-1.5-pro", "Gemini 1.5 Pro"),
        ("gemini-1.5-flash", "Gemini 1.5 Flash"),
    )
]
| 42 | + | |
| 43 | + | |
class VertexProvider(BaseProvider):
    """Google Vertex AI provider using google-genai SDK with Vertex config."""

    provider_name = "vertex"

    # Registry routing prefix stripped from incoming model ids.
    _MODEL_PREFIX = "vertex/"

    def __init__(
        self,
        project: Optional[str] = None,
        location: Optional[str] = None,
    ):
        """Create a Vertex-backed genai client.

        Args:
            project: GCP project id; falls back to GOOGLE_CLOUD_PROJECT.
            location: GCP region; falls back to GOOGLE_CLOUD_REGION,
                defaulting to "us-central1".

        Raises:
            ImportError: If the google-genai SDK is unavailable.
            ValueError: If no project id can be resolved.
        """
        try:
            from google import genai
            from google.genai import types  # noqa: F401
        except ImportError:
            raise ImportError(
                "google-cloud-aiplatform or google-genai package not installed. "
                "Install with: pip install google-cloud-aiplatform"
            )

        self._genai = genai
        self._project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
        self._location = location or os.getenv("GOOGLE_CLOUD_REGION", "us-central1")

        if not self._project:
            raise ValueError("GOOGLE_CLOUD_PROJECT not set")

        self.client = genai.Client(
            vertexai=True,
            project=self._project,
            location=self._location,
        )
        # Token counts from the most recent API call.
        self._last_usage: dict = {}

    @classmethod
    def _resolve_model(cls, model: Optional[str]) -> str:
        """Apply the default model and strip the "vertex/" routing prefix."""
        model = model or "gemini-2.0-flash"
        if model.startswith(cls._MODEL_PREFIX):
            model = model[len(cls._MODEL_PREFIX):]
        return model

    def _record_usage(self, response) -> None:
        """Capture input/output token counts from *response*, if reported."""
        um = getattr(response, "usage_metadata", None)
        self._last_usage = {
            "input_tokens": getattr(um, "prompt_token_count", 0) if um else 0,
            "output_tokens": getattr(um, "candidates_token_count", 0) if um else 0,
        }

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
        model: Optional[str] = None,
    ) -> str:
        """Run a chat completion and return the response text."""
        from google.genai import types

        model = self._resolve_model(model)

        contents = []
        for msg in messages:
            # genai uses "model" for assistant turns; any non-user role maps there.
            role = "user" if msg["role"] == "user" else "model"
            contents.append(
                types.Content(
                    role=role,
                    parts=[types.Part.from_text(text=msg["content"])],
                )
            )

        response = self.client.models.generate_content(
            model=model,
            contents=contents,
            config=types.GenerateContentConfig(
                max_output_tokens=max_tokens,
                temperature=temperature,
            ),
        )
        self._record_usage(response)
        return response.text or ""

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
        model: Optional[str] = None,
    ) -> str:
        """Describe an image with Gemini (bytes assumed JPEG — TODO confirm)."""
        from google.genai import types

        model = self._resolve_model(model)

        response = self.client.models.generate_content(
            model=model,
            contents=[
                types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg"),
                prompt,
            ],
            config=types.GenerateContentConfig(
                max_output_tokens=max_tokens,
            ),
        )
        self._record_usage(response)
        return response.text or ""

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        model: Optional[str] = None,
    ) -> dict:
        """Transcribe an audio file with Gemini.

        Args:
            audio_path: Audio file path; MIME type is inferred from the
                suffix, defaulting to audio/wav.
            language: Optional language hint, appended to the prompt.
            model: Gemini model id, optionally prefixed with "vertex/".

        Returns:
            Dict with "text", "segments", "language", "duration" (always
            None here), "provider" and "model" keys.
        """
        import json

        from google.genai import types

        model = self._resolve_model(model)

        audio_path = Path(audio_path)
        suffix = audio_path.suffix.lower()
        mime_map = {
            ".wav": "audio/wav",
            ".mp3": "audio/mpeg",
            ".m4a": "audio/mp4",
            ".flac": "audio/flac",
            ".ogg": "audio/ogg",
            ".webm": "audio/webm",
        }
        mime_type = mime_map.get(suffix, "audio/wav")
        audio_bytes = audio_path.read_bytes()

        lang_hint = f" The audio is in {language}." if language else ""
        # Fix: interpolate lang_hint — the prompt previously contained a stray
        # "python" literal and silently dropped the language hint.
        prompt = (
            f"Transcribe this audio accurately.{lang_hint} "
            "Return a JSON object with keys: "
            '"text" (full transcript), '
            '"segments" (array of {start, end, text} objects with timestamps in seconds).'
        )

        response = self.client.models.generate_content(
            model=model,
            contents=[
                types.Part.from_bytes(data=audio_bytes, mime_type=mime_type),
                prompt,
            ],
            config=types.GenerateContentConfig(
                max_output_tokens=8192,
                response_mime_type="application/json",
            ),
        )

        try:
            data = json.loads(response.text)
        except (json.JSONDecodeError, TypeError):
            # Model returned non-JSON despite response_mime_type; keep raw text.
            data = {"text": response.text or "", "segments": []}

        self._record_usage(response)

        return {
            "text": data.get("text", ""),
            "segments": data.get("segments", []),
            "language": language,
            "duration": None,
            "provider": "vertex",
            "model": model,
        }

    def list_models(self) -> list[ModelInfo]:
        """Return a copy of the curated Vertex model catalog."""
        return list(_VERTEX_MODELS)
| 214 | + | |
| 215 | + | |
# Vertex models are addressed with a "vertex/" routing prefix (e.g.
# "vertex/gemini-2.0-flash"); Gemini serves chat, vision and audio defaults.
ProviderRegistry.register(
    name="vertex",
    provider_class=VertexProvider,
    env_var="GOOGLE_CLOUD_PROJECT",
    model_prefixes=["vertex/"],
    default_models={
        "chat": "gemini-2.0-flash",
        "vision": "gemini-2.0-flash",
        "audio": "gemini-2.0-flash",
    },
)
| --- a/video_processor/providers/vertex_provider.py | |
| +++ b/video_processor/providers/vertex_provider.py | |
| @@ -0,0 +1,226 @@ | |
| --- a/video_processor/providers/vertex_provider.py | |
| +++ b/video_processor/providers/vertex_provider.py | |
| @@ -0,0 +1,226 @@ | |
| 1 | """Google Vertex AI provider implementation.""" |
| 2 | |
| 3 | import logging |
| 4 | import os |
| 5 | from pathlib import Path |
| 6 | from typing import Optional |
| 7 | |
| 8 | from dotenv import load_dotenv |
| 9 | |
| 10 | from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry |
| 11 | |
| 12 | load_dotenv() |
| 13 | logger = logging.getLogger(__name__) |
| 14 | |
# Curated list of models available on Vertex AI; every Gemini variant here
# handles chat, vision and audio.
_VERTEX_MODELS = [
    ModelInfo(
        id=model_id,
        provider="vertex",
        display_name=label,
        capabilities=["chat", "vision", "audio"],
    )
    for model_id, label in (
        ("gemini-2.0-flash", "Gemini 2.0 Flash"),
        ("gemini-2.0-pro", "Gemini 2.0 Pro"),
        ("gemini-1.5-pro", "Gemini 1.5 Pro"),
        ("gemini-1.5-flash", "Gemini 1.5 Flash"),
    )
]
| 42 | |
| 43 | |
class VertexProvider(BaseProvider):
    """Google Vertex AI provider using google-genai SDK with Vertex config."""

    provider_name = "vertex"

    # Registry routing prefix stripped from incoming model ids.
    _MODEL_PREFIX = "vertex/"

    def __init__(
        self,
        project: Optional[str] = None,
        location: Optional[str] = None,
    ):
        """Create a Vertex-backed genai client.

        Args:
            project: GCP project id; falls back to GOOGLE_CLOUD_PROJECT.
            location: GCP region; falls back to GOOGLE_CLOUD_REGION,
                defaulting to "us-central1".

        Raises:
            ImportError: If the google-genai SDK is unavailable.
            ValueError: If no project id can be resolved.
        """
        try:
            from google import genai
            from google.genai import types  # noqa: F401
        except ImportError:
            raise ImportError(
                "google-cloud-aiplatform or google-genai package not installed. "
                "Install with: pip install google-cloud-aiplatform"
            )

        self._genai = genai
        self._project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
        self._location = location or os.getenv("GOOGLE_CLOUD_REGION", "us-central1")

        if not self._project:
            raise ValueError("GOOGLE_CLOUD_PROJECT not set")

        self.client = genai.Client(
            vertexai=True,
            project=self._project,
            location=self._location,
        )
        # Token counts from the most recent API call.
        self._last_usage: dict = {}

    @classmethod
    def _resolve_model(cls, model: Optional[str]) -> str:
        """Apply the default model and strip the "vertex/" routing prefix."""
        model = model or "gemini-2.0-flash"
        if model.startswith(cls._MODEL_PREFIX):
            model = model[len(cls._MODEL_PREFIX):]
        return model

    def _record_usage(self, response) -> None:
        """Capture input/output token counts from *response*, if reported."""
        um = getattr(response, "usage_metadata", None)
        self._last_usage = {
            "input_tokens": getattr(um, "prompt_token_count", 0) if um else 0,
            "output_tokens": getattr(um, "candidates_token_count", 0) if um else 0,
        }

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
        model: Optional[str] = None,
    ) -> str:
        """Run a chat completion and return the response text."""
        from google.genai import types

        model = self._resolve_model(model)

        contents = []
        for msg in messages:
            # genai uses "model" for assistant turns; any non-user role maps there.
            role = "user" if msg["role"] == "user" else "model"
            contents.append(
                types.Content(
                    role=role,
                    parts=[types.Part.from_text(text=msg["content"])],
                )
            )

        response = self.client.models.generate_content(
            model=model,
            contents=contents,
            config=types.GenerateContentConfig(
                max_output_tokens=max_tokens,
                temperature=temperature,
            ),
        )
        self._record_usage(response)
        return response.text or ""

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
        model: Optional[str] = None,
    ) -> str:
        """Describe an image with Gemini (bytes assumed JPEG — TODO confirm)."""
        from google.genai import types

        model = self._resolve_model(model)

        response = self.client.models.generate_content(
            model=model,
            contents=[
                types.Part.from_bytes(data=image_bytes, mime_type="image/jpeg"),
                prompt,
            ],
            config=types.GenerateContentConfig(
                max_output_tokens=max_tokens,
            ),
        )
        self._record_usage(response)
        return response.text or ""

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        model: Optional[str] = None,
    ) -> dict:
        """Transcribe an audio file with Gemini.

        Args:
            audio_path: Audio file path; MIME type is inferred from the
                suffix, defaulting to audio/wav.
            language: Optional language hint, appended to the prompt.
            model: Gemini model id, optionally prefixed with "vertex/".

        Returns:
            Dict with "text", "segments", "language", "duration" (always
            None here), "provider" and "model" keys.
        """
        import json

        from google.genai import types

        model = self._resolve_model(model)

        audio_path = Path(audio_path)
        suffix = audio_path.suffix.lower()
        mime_map = {
            ".wav": "audio/wav",
            ".mp3": "audio/mpeg",
            ".m4a": "audio/mp4",
            ".flac": "audio/flac",
            ".ogg": "audio/ogg",
            ".webm": "audio/webm",
        }
        mime_type = mime_map.get(suffix, "audio/wav")
        audio_bytes = audio_path.read_bytes()

        lang_hint = f" The audio is in {language}." if language else ""
        # Fix: interpolate lang_hint — the prompt previously contained a stray
        # "python" literal and silently dropped the language hint.
        prompt = (
            f"Transcribe this audio accurately.{lang_hint} "
            "Return a JSON object with keys: "
            '"text" (full transcript), '
            '"segments" (array of {start, end, text} objects with timestamps in seconds).'
        )

        response = self.client.models.generate_content(
            model=model,
            contents=[
                types.Part.from_bytes(data=audio_bytes, mime_type=mime_type),
                prompt,
            ],
            config=types.GenerateContentConfig(
                max_output_tokens=8192,
                response_mime_type="application/json",
            ),
        )

        try:
            data = json.loads(response.text)
        except (json.JSONDecodeError, TypeError):
            # Model returned non-JSON despite response_mime_type; keep raw text.
            data = {"text": response.text or "", "segments": []}

        self._record_usage(response)

        return {
            "text": data.get("text", ""),
            "segments": data.get("segments", []),
            "language": language,
            "duration": None,
            "provider": "vertex",
            "model": model,
        }

    def list_models(self) -> list[ModelInfo]:
        """Return a copy of the curated Vertex model catalog."""
        return list(_VERTEX_MODELS)
| 214 | |
| 215 | |
# Vertex models are addressed with a "vertex/" routing prefix (e.g.
# "vertex/gemini-2.0-flash"); Gemini serves chat, vision and audio defaults.
ProviderRegistry.register(
    name="vertex",
    provider_class=VertexProvider,
    env_var="GOOGLE_CLOUD_PROJECT",
    model_prefixes=["vertex/"],
    default_models={
        "chat": "gemini-2.0-flash",
        "vision": "gemini-2.0-flash",
        "audio": "gemini-2.0-flash",
    },
)