/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.huggingface;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.DimensionAwareEmbeddingModel;
import dev.langchain4j.model.huggingface.FactoryCreator;
import dev.langchain4j.model.huggingface.client.EmbeddingRequest;
import dev.langchain4j.model.huggingface.client.HuggingFaceClient;
import dev.langchain4j.model.huggingface.spi.HuggingFaceClientFactory;
import dev.langchain4j.model.huggingface.spi.HuggingFaceEmbeddingModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.spi.ServiceHelper;
import java.time.Duration;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Embedding model that delegates to a {@link HuggingFaceClient} created through
 * {@code FactoryCreator.FACTORY}.
 *
 * <p>When not configured explicitly, the model ID defaults to
 * {@code sentence-transformers/all-MiniLM-L6-v2}, the timeout to 15 seconds, and
 * {@code waitForModel} to {@code true}.
 */
public class HuggingFaceEmbeddingModel extends DimensionAwareEmbeddingModel {
    private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(15L);
    private static final String DEFAULT_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2";

    private final HuggingFaceClient client;
    private final boolean waitForModel;
    /** The model ID actually handed to the client (the default is substituted when none was given). */
    private final String modelId;

    /**
     * Creates the model and its underlying HuggingFace client.
     *
     * @param accessToken  HuggingFace API token; must be non-null and non-blank
     * @param modelId      model to query, or {@code null} for {@code sentence-transformers/all-MiniLM-L6-v2}
     * @param waitForModel forwarded on every {@link EmbeddingRequest}; {@code null} is treated as {@code true}
     * @param timeout      client timeout, or {@code null} for 15 seconds
     * @throws IllegalArgumentException if {@code accessToken} is null or blank
     */
    public HuggingFaceEmbeddingModel(final String accessToken, final String modelId, Boolean waitForModel, final Duration timeout) {
        if (accessToken == null || accessToken.trim().isEmpty()) {
            throw new IllegalArgumentException("HuggingFace access token must be defined. It can be generated here: https://huggingface.co/settings/tokens");
        }
        // Resolve defaults once so the client and this instance agree on the effective configuration.
        final String resolvedModelId = modelId == null ? DEFAULT_MODEL_ID : modelId;
        final Duration resolvedTimeout = timeout == null ? DEFAULT_TIMEOUT : timeout;

        this.client = FactoryCreator.FACTORY.create(new HuggingFaceClientFactory.Input() {

            @Override
            public String apiKey() {
                return accessToken;
            }

            @Override
            public String modelId() {
                return resolvedModelId;
            }

            @Override
            public Duration timeout() {
                return resolvedTimeout;
            }
        });
        // Only an explicit Boolean.FALSE disables waiting; null keeps the default of true.
        this.waitForModel = waitForModel == null || waitForModel;
        this.modelId = resolvedModelId;
    }

    /**
     * Embeds the text of each segment, in order.
     *
     * @param textSegments segments to embed
     * @return one embedding per segment, in the same order as reported by the client
     */
    @Override
    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
        List<String> texts = textSegments.stream().map(TextSegment::text).collect(Collectors.toList());
        return this.embedTexts(texts);
    }

    /** Sends the texts to the client in a single request and wraps the returned vectors. */
    private Response<List<Embedding>> embedTexts(List<String> texts) {
        if (texts.isEmpty()) {
            // Nothing to embed: skip the remote round trip entirely.
            List<Embedding> none = Collections.emptyList();
            return Response.from(none);
        }
        EmbeddingRequest request = new EmbeddingRequest(texts, this.waitForModel);
        List<float[]> response = this.client.embed(request);
        List<Embedding> embeddings = response.stream().map(Embedding::from).collect(Collectors.toList());
        return Response.from(embeddings);
    }

    /**
     * Shortcut for {@code builder().accessToken(accessToken).build()}, using all other defaults.
     *
     * @param accessToken HuggingFace API token
     * @return a new model instance
     */
    public static HuggingFaceEmbeddingModel withAccessToken(String accessToken) {
        return HuggingFaceEmbeddingModel.builder().accessToken(accessToken).build();
    }

    /**
     * Returns a builder for this model.
     *
     * <p>If {@link ServiceHelper#loadFactories} yields any
     * {@link HuggingFaceEmbeddingModelBuilderFactory}, the builder it supplies is used; otherwise a
     * plain {@link HuggingFaceEmbeddingModelBuilder} is returned.
     *
     * @return a builder instance
     */
    public static HuggingFaceEmbeddingModelBuilder builder() {
        Iterator<HuggingFaceEmbeddingModelBuilderFactory> factories =
                ServiceHelper.loadFactories(HuggingFaceEmbeddingModelBuilderFactory.class).iterator();
        if (factories.hasNext()) {
            // Cast kept because the factory's supplier type is declared outside this file.
            return (HuggingFaceEmbeddingModelBuilder) factories.next().get();
        }
        return new HuggingFaceEmbeddingModelBuilder();
    }

    /** Mutable builder collecting the constructor arguments of {@link HuggingFaceEmbeddingModel}. */
    public static class HuggingFaceEmbeddingModelBuilder {
        private String accessToken;
        private String modelId;
        private Boolean waitForModel;
        private Duration timeout;

        /** Sets the HuggingFace API token (required). */
        public HuggingFaceEmbeddingModelBuilder accessToken(String accessToken) {
            this.accessToken = accessToken;
            return this;
        }

        /** Sets the model ID; {@code null} selects the default model. */
        public HuggingFaceEmbeddingModelBuilder modelId(String modelId) {
            this.modelId = modelId;
            return this;
        }

        /** Sets the wait-for-model flag; {@code null} is treated as {@code true}. */
        public HuggingFaceEmbeddingModelBuilder waitForModel(Boolean waitForModel) {
            this.waitForModel = waitForModel;
            return this;
        }

        /** Sets the client timeout; {@code null} selects 15 seconds. */
        public HuggingFaceEmbeddingModelBuilder timeout(Duration timeout) {
            this.timeout = timeout;
            return this;
        }

        /**
         * Builds the model from the collected settings.
         *
         * @throws IllegalArgumentException if the access token is null or blank
         */
        public HuggingFaceEmbeddingModel build() {
            return new HuggingFaceEmbeddingModel(this.accessToken, this.modelId, this.waitForModel, this.timeout);
        }

        /** Describes the builder's state; the access token is masked so it never leaks into logs. */
        @Override
        public String toString() {
            String maskedToken = this.accessToken == null ? "null" : "***";
            return "HuggingFaceEmbeddingModel.HuggingFaceEmbeddingModelBuilder(accessToken=" + maskedToken + ", modelId=" + this.modelId + ", waitForModel=" + this.waitForModel + ", timeout=" + this.timeout + ")";
        }
    }
}

