#!/usr/bin/env python3
import requests
import json
import time
import argparse
import os
from typing import Optional, Dict, Any


def _extract_cost_data(stats: Any) -> Optional[Dict[str, Any]]:
    """Return the dict holding generation stats from a decoded response, or None.

    Two response shapes are accepted: stats nested under a ``"data"`` dict, or a
    flat dict that carries ``"total_cost"`` directly. Anything else yields None.
    """
    if not isinstance(stats, dict):
        return None
    data = stats.get("data")
    if data and isinstance(data, dict):
        return data
    if "total_cost" in stats:
        return stats
    return None


def _truncate_json(payload: Any, limit: int = 1000) -> str:
    """Pretty-print *payload* as JSON, cut to *limit* characters plus "..." if longer."""
    text = json.dumps(payload, indent=2)
    return text[:limit] + "..." if len(text) > limit else text


def _print_summary(generation_id: str, result: Dict[str, Any]) -> None:
    """Print a human-readable summary of a formatted generation result."""
    cost = result["cost"]
    # total_cost may be absent/null; formatting None with :.6f would raise.
    cost_text = f"${cost:.6f}" if isinstance(cost, (int, float)) else "unknown"
    print(f"\nSuccess! OpenRouter generation with ID {generation_id}:")
    print(f"- Model: {result['model']}")
    print(f"- Tokens: {result['total_tokens']} ({result['prompt_tokens']} prompt, {result['completion_tokens']} completion)")
    print(f"- Cost: {cost_text}")
    print(f"- Generation time: {result['generation_time']}s\n")


def get_generation_info(
    generation_id: str,
    api_key: str,
    model: str = "",
    max_retries: int = 3,
    initial_delay: float = 1.0,
    request_timeout: float = 30.0,
) -> Optional[Dict[str, Any]]:
    """
    Get detailed information about a specific generation from OpenRouter.

    Every attempt first sleeps ``initial_delay * 1.5 ** attempt_index`` seconds,
    including the first attempt. It then tries two endpoint URL formats in turn.
    The first response that contains stats is formatted and returned.

    Args:
        generation_id: The generation ID to look up.
        api_key: OpenRouter API key, sent as a Bearer token.
        model: The model used. If it contains "gpt-4o" but not "mini", the
            retry parameters are overridden (initial delay 1.5s, 5 retries).
            None is treated as "".
        max_retries: Maximum number of retry attempts.
        initial_delay: Initial delay in seconds before the first attempt.
            The delay grows by a factor of 1.5 per attempt.
        request_timeout: Per-request timeout in seconds passed to ``requests.get``.

    Returns:
        A dict with keys ``id``, ``model``, ``generation_time``, ``prompt_tokens``,
        ``completion_tokens``, ``total_tokens``, ``cost``, ``timestamp`` and
        ``raw_data``. Returns None if no attempt yielded stats.
    """
    model = model or ""

    # Full (non-mini) GPT-4o models get a longer, more persistent retry schedule.
    # Note this overrides whatever the caller passed.
    if "gpt-4o" in model and "mini" not in model:
        initial_delay = 1.5
        max_retries = 5
        print(f"Using extended retry parameters for GPT-4o model: initial delay {initial_delay}s, max retries {max_retries}")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "HTTP-Referer": "https://llm-olympics.com",
        "X-Title": "LLM Olympics",
    }

    endpoints = (
        f"https://openrouter.ai/api/v1/generation?id={generation_id}",
        f"https://openrouter.ai/api/v1/generations/{generation_id}",
    )

    for retry_count in range(max_retries):
        attempt = retry_count + 1
        # Wait with exponential backoff (also before the very first request).
        wait_time = initial_delay * (1.5 ** retry_count)
        print(f"Waiting {wait_time:.2f}s before attempting to fetch generation info (attempt {attempt}/{max_retries})...")
        time.sleep(wait_time)

        for endpoint in endpoints:
            # Keep the try narrow: only network/HTTP/JSON-decoding failures are expected here.
            try:
                print(f"Trying endpoint: {endpoint}")
                response = requests.get(endpoint, headers=headers, timeout=request_timeout)
                response.raise_for_status()
                stats = response.json()
            except requests.exceptions.HTTPError as e:
                if e.response is not None and e.response.status_code == 404:
                    print(f"404 Not Found with endpoint {endpoint}, trying alternative endpoint...")
                else:
                    print(f"HTTP error with endpoint {endpoint}: {e}")
                continue
            except (requests.exceptions.RequestException, ValueError) as e:
                # ValueError covers JSON decode errors across requests versions.
                print(f"Error with endpoint {endpoint}: {e}")
                continue

            print(f"Generation stats response (attempt {attempt}):", _truncate_json(stats))

            cost_data = _extract_cost_data(stats)
            if not cost_data:
                print(f"No cost data found in response on attempt {attempt}, will try next endpoint or retry...")
                continue

            prompt_tokens = cost_data.get("tokens_prompt")
            completion_tokens = cost_data.get("tokens_completion")
            result = {
                "id": cost_data.get("id", generation_id),
                "model": cost_data.get("model", model),
                "generation_time": cost_data.get("generation_time"),
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                # Token fields may be present but null; treat those as 0 when summing.
                "total_tokens": (prompt_tokens or 0) + (completion_tokens or 0),
                "cost": cost_data.get("total_cost"),
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "raw_data": cost_data,
            }

            _print_summary(generation_id, result)
            return result

    print(f"Failed to retrieve generation info after {max_retries} attempts")
    return None


def main():
    """Command-line entry point: look up one OpenRouter generation and optionally save it.

    The API key comes from ``--api-key``, falling back to the
    ``OPENROUTER_API_KEY`` environment variable.

    Returns:
        Process exit code: 0 if generation info was retrieved, 1 otherwise.

    Raises:
        ValueError: If no API key is supplied by either source.
    """
    parser = argparse.ArgumentParser(description="Get generation information from OpenRouter API")
    parser.add_argument("generation_id", help="Generation ID to look up")
    parser.add_argument("--api-key", help="OpenRouter API key (defaults to OPENROUTER_API_KEY env variable)")
    parser.add_argument("--model", default="", help="Model name (optional, helps optimize retry strategy)")
    parser.add_argument("--output", help="Output file path for JSON results (optional)")
    args = parser.parse_args()

    # Resolve the key as the --api-key help text promises. Secrets must never be
    # hard-coded in source; rotate any key that was ever committed.
    api_key = args.api_key or os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        raise ValueError("OpenRouter API key not provided. Set OPENROUTER_API_KEY environment variable or use --api-key")

    # Mask API key for console output
    masked_key = f"{api_key[:5]}...{api_key[-4:]}" if len(api_key) > 9 else "****"
    print(f"Using OpenRouter API key: {masked_key}")

    # Get generation info
    result = get_generation_info(args.generation_id, api_key, args.model)

    if result and args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2)
        print(f"Full results saved to {args.output}")

    # Return appropriate exit code
    return 0 if result else 1


if __name__ == "__main__":
    import sys

    try:
        # sys.exit raises SystemExit, which is not an Exception subclass and so
        # passes through the handler below (as do argparse's own exits).
        sys.exit(main())
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)