PlanOpticon
| 0981a08… | noreply | 1 | """AI21 Labs provider implementation.""" |
| 0981a08… | noreply | 2 | |
| 0981a08… | noreply | 3 | import logging |
| 0981a08… | noreply | 4 | import os |
| 0981a08… | noreply | 5 | from pathlib import Path |
| 0981a08… | noreply | 6 | from typing import Optional |
| 0981a08… | noreply | 7 | |
| 0981a08… | noreply | 8 | from dotenv import load_dotenv |
| 0981a08… | noreply | 9 | |
| 0981a08… | noreply | 10 | from video_processor.providers.base import ModelInfo, OpenAICompatibleProvider, ProviderRegistry |
| 0981a08… | noreply | 11 | |
| 0981a08… | noreply | 12 | load_dotenv() |
| 0981a08… | noreply | 13 | logger = logging.getLogger(__name__) |
| 0981a08… | noreply | 14 | |
| 0981a08… | noreply | 15 | # Curated list of AI21 models |
# Curated list of AI21 chat models exposed by this provider.
# All entries share provider="ai21" and chat-only capabilities, so the
# list is derived from (model id, display name) pairs.
_AI21_MODELS = [
    ModelInfo(
        id=model_id,
        provider="ai21",
        display_name=label,
        capabilities=["chat"],
    )
    for model_id, label in [
        ("jamba-1.5-large", "Jamba 1.5 Large"),
        ("jamba-1.5-mini", "Jamba 1.5 Mini"),
        ("jamba-instruct", "Jamba Instruct"),
    ]
]
| 0981a08… | noreply | 36 | |
| 0981a08… | noreply | 37 | |
class AI21Provider(OpenAICompatibleProvider):
    """AI21 Labs provider using the OpenAI-compatible Studio API.

    Only chat is supported; AI21 does not offer vision or transcription
    endpoints, so those methods raise ``NotImplementedError``.
    """

    provider_name = "ai21"
    base_url = "https://api.ai21.com/studio/v1"
    env_var = "AI21_API_KEY"
    # Model used by chat() when the caller does not pass one explicitly.
    default_chat_model = "jamba-1.5-large"

    def __init__(self, api_key: Optional[str] = None):
        """Initialize the provider.

        Args:
            api_key: Explicit API key. Falls back to the environment
                variable named by ``env_var`` when omitted.

        Raises:
            ValueError: If no API key is provided or found in the
                environment.
        """
        # Read the declared env_var attribute instead of repeating the
        # literal, so the class attribute is the single source of truth
        # (and subclasses overriding env_var behave correctly).
        api_key = api_key or os.getenv(self.env_var)
        if not api_key:
            raise ValueError(f"{self.env_var} not set")
        super().__init__(api_key=api_key, base_url=self.base_url)

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
        model: Optional[str] = None,
    ) -> str:
        """Send a chat completion request and return the response text.

        Args:
            messages: OpenAI-style message dicts (role/content).
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.
            model: Model id; defaults to ``default_chat_model``.
        """
        model = model or self.default_chat_model
        return super().chat(messages, max_tokens, temperature, model)

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
        model: Optional[str] = None,
    ) -> str:
        """Unsupported: AI21 has no vision endpoint."""
        raise NotImplementedError(
            "AI21 does not currently support vision/image analysis. "
            "Use OpenAI, Anthropic, or Gemini for image analysis."
        )

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        model: Optional[str] = None,
    ) -> dict:
        """Unsupported: AI21 has no transcription endpoint."""
        raise NotImplementedError(
            "AI21 does not provide a transcription API. "
            "Use OpenAI Whisper or Gemini for transcription."
        )

    def list_models(self) -> list[ModelInfo]:
        """Return a copy of the curated AI21 model list."""
        return list(_AI21_MODELS)
| 0981a08… | noreply | 86 | |
| 0981a08… | noreply | 87 | |
# Per-capability defaults: AI21 supports chat only, so vision/audio map
# to empty strings (no default model available).
_AI21_DEFAULT_MODELS = {
    "chat": "jamba-1.5-large",
    "vision": "",
    "audio": "",
}

# Register this provider so it can be resolved by name ("ai21") or by
# model-id prefix (jamba-*, j2-*).
ProviderRegistry.register(
    name="ai21",
    provider_class=AI21Provider,
    env_var="AI21_API_KEY",
    model_prefixes=["jamba-", "j2-"],
    default_models=_AI21_DEFAULT_MODELS,
)