import datetime
import logging
import os

import openai
import uvicorn
from dotenv import find_dotenv, load_dotenv
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

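# Load environment variables (e.g. OPENAI_API_KEY) from a local .env file if one exists.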
_ = load_dotenv(find_dotenv())


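# FastAPI application with permissive CORS so the front end can call the API from the browser.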
app = FastAPI(
    title="LangChain Server",
    version="1.0",
    description="A simple API server using LangChain's Runnable interfaces",
)

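# Note: the "*" entry already allows every origin, which makes the explicit hosts below redundant.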
origins = [
    "http://localhost",
    "http://localhost:8000",
    "http://localhost:3000",
    "http://165.232.178.158:8008",
    "*",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

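# Print the working directory so the relative Chroma persist path below can be verified.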
current_directory = os.getcwd()
print("Current Directory:", current_directory)

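# The OpenAI API key comes from the environment (populated above via .env).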
openai.api_key = os.environ['OPENAI_API_KEY']

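# Chroma vector store persisted on disk; queries are embedded with text-embedding-3-large.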
persist_directory = 'docs/chroma/'
embedding = OpenAIEmbeddings(model="text-embedding-3-large")
vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)
os.chmod(persist_directory, 0o755)

# The dated snapshot "gpt-3.5-turbo-0301" has been retired, and any current date is past
# 2023-09-02, so the original date check always resolved to the base model name.
llm_name = "gpt-3.5-turbo"
llm = ChatOpenAI(model_name=llm_name, temperature=0)



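# Debug-level logging to details_log.log; filemode='w' starts a fresh log file on every launch.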
logging.basicConfig(level=logging.DEBUG, 
                    format='%(asctime)s - %(levelname)s - %(message)s', 
                    filename='details_log.log', 
                    filemode='w')

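# Request body schema for the /evaluate_nudge endpoint.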
class Input(BaseModel):
    question: str
    student_answer: str
    prompt: str

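# Default prompt used when the caller sends an empty "prompt". The model is asked to write a
# summary followed by a section labelled NUDGE, which the endpoint splits on below.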
prompt_template = """Based on the context: {context} and the question: {query}, first generate an ideal answer that directly addresses the question. Then, evaluate the student's response: {answer}.
If the response is irrelevant or nonsensical, provide a NUDGE, suitable for an A2-level learner, that explains why the student's response does not address the question and guides the student back to the topic.
If the response is relevant but lacks clarity or detail, provide a NUDGE, suitable for an A2-level learner, that encourages the student to elaborate on their response, specifically highlighting any missing points or areas for further exploration.
Ensure the nudge is written in normal sentence case, suitable for an A2-level learner, focusing on clear and actionable feedback.
SUMMARY:"""


def summary(que, context, answer, prompt_temp):
    """Run the LLM chain on the question, context, and student answer; the output dict holds the generation under "text"."""
    prompt = PromptTemplate(
        input_variables=["query", "context", "answer"],
        template=prompt_temp,
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    result = chain.invoke({"query": que, "context": context, "answer": answer})
    return result

def answer_relevance(student_answer, context_summary):
    """Score lexical overlap between the student answer and the summary via TF-IDF cosine similarity."""
    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform([student_answer, context_summary])
    similarity_score = cosine_similarity(tfidf_matrix)[0, 1]
    return round(max(0, similarity_score), 2)

def log_track(errortype, msg):
    """Record a categorised debug entry in the log file."""
    logging.debug(f'{errortype}: {msg}')


# The endpoint accepts a JSON body matching the Input model.
@app.post("/evaluate_nudge/")
async def evaluate_nudge(input: Input):
    """Retrieve context for the question, generate a summary plus nudge with the LLM, and score the answer's relevance."""
    logging.info("function name: evaluate_nudge")
    question = input.question
    student_answer = input.student_answer
    prompt = input.prompt
    contexts = []
    # Check if the student's answer is too short or vague
    if len(student_answer) <= 10:  # Adjust threshold if needed (e.g., for "ok" or very short answers)
        logging.info(f"Short answer detected: {student_answer}")
        return {
            "question": question,
            "student_answer": student_answer,
            "nudge": "Your answer is too short. Please provide a more detailed response that explains your reasoning.",
            "relevance_score": 0 
        }
    if len(prompt) <= 0:
        # No custom prompt supplied: use the default template with the retrieved context.
        prompt = prompt_template
        retriever = vectordb.as_retriever(search_kwargs={"k": 8})
        doc = retriever.get_relevant_documents(question)
        logging.debug(doc)
        contents = []
        context = ""
        for d in doc:
            contents.append(d.page_content)
            context += d.page_content
        contexts.append(contents)

        con_sum = summary(question, context, student_answer, prompt)
        generated_text = con_sum["text"]
        summaryc, nudge = generated_text.split("NUDGE", 1) if "NUDGE" in generated_text else (generated_text, "")
        relevance_score = answer_relevance(student_answer, summaryc)

        if nudge == '':
            # The model did not produce a NUDGE section; log it and rerun the chain once.
            log_track('category : Blank Nudge', nudge)
            con_sum = summary(question, context, student_answer, prompt)
            generated_text = con_sum["text"]
            summaryc, nudge = generated_text.split("NUDGE", 1) if "NUDGE" in generated_text else (generated_text, "")
            relevance_score = answer_relevance(student_answer, summaryc)
        response_data = {
            "question": question,
            "student_answer": student_answer,
            "context": contexts,
            "Summary": summaryc,
            "nudge": nudge,
            "relevance_score": relevance_score
        }
        return response_data
    else:
        # Custom prompt supplied by the caller: run it against the retrieved context.
        retriever = vectordb.as_retriever(search_kwargs={"k": 8})
        doc = retriever.get_relevant_documents(question)
        logging.debug(doc)
        contents = []
        context = ""
        for d in doc:
            contents.append(d.page_content)
            context += d.page_content
        contexts.append(contents)
        con_sum = summary(question, context, student_answer, prompt)
        relevance_score = answer_relevance(student_answer, con_sum["text"])
        response_data = {
            "question": question,
            "student_answer": student_answer,
            "context": contexts,
            "Prompt_response": con_sum["text"],
            "relevance_score": relevance_score
        }
        return response_data

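# Example request (assumes a local run on port 8000; adjust host/port and payload to your deployment):
#   curl -X POST http://localhost:8000/evaluate_nudge/ \
#     -H "Content-Type: application/json" \
#     -d '{"question": "What is a loop?", "student_answer": "It repeats a block of code.", "prompt": ""}'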

if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        # Let's Encrypt layout: privkey.pem is the private key, fullchain.pem the certificate chain.
        ssl_keyfile="/etc/letsencrypt/live/divercityapi.anudip.org/privkey.pem",
        ssl_certfile="/etc/letsencrypt/live/divercityapi.anudip.org/fullchain.pem",
    )
