Skip to content

Feat/support ollama embedding #948

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions packages/providers/src/embeddings/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { Embeddings } from '@langchain/core/embeddings';
import { OpenAIEmbeddings } from '@langchain/openai';
import { FireworksEmbeddings } from '@langchain/community/embeddings/fireworks';
import { JinaEmbeddings } from './jina';
import { OllamaEmbeddings } from './ollama';
import { BaseProvider } from '../types';

export const getEmbeddings = (provider: BaseProvider, config: EmbeddingModelConfig): Embeddings => {
Expand All @@ -29,6 +30,15 @@ export const getEmbeddings = (provider: BaseProvider, config: EmbeddingModelConf
apiKey: provider.apiKey,
maxRetries: 3,
});
case 'ollama':
return new OllamaEmbeddings({
model: config.modelId,
batchSize: config.batchSize,
dimensions: config.dimensions,
baseUrl: provider.baseUrl || 'http://localhost:11434/v1',
apiKey: provider.apiKey,
maxRetries: 3,
});
default:
throw new Error(`Unsupported embeddings provider: ${provider.providerKey}`);
}
Expand Down
68 changes: 68 additions & 0 deletions packages/providers/src/embeddings/ollama.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import { Embeddings } from '@langchain/core/embeddings';
import { OllamaEmbeddings as LangChainOllamaEmbeddings } from '@langchain/ollama';

/**
 * Constructor options for the Ollama embeddings adapter.
 *
 * The whole object is handed to the base `Embeddings` constructor; the adapter
 * itself reads only `model` and `baseUrl` when building the underlying client.
 */
export interface OllamaEmbeddingsConfig {
  // --- required ---

  /** Name of the embedding model to use; must be non-empty. */
  model: string;
  /** Base URL of the Ollama server; must be non-empty. */
  baseUrl: string;

  // --- optional ---

  /** Optional API key. NOTE(review): not read by the adapter class in this file. */
  apiKey?: string;
  /** Optional batch size. NOTE(review): not read by the adapter class in this file. */
  batchSize?: number;
  /** Optional embedding dimensionality. NOTE(review): not read by the adapter class in this file. */
  dimensions?: number;
  /** Optional retry count; reaches the base class only via `super(config)`. */
  maxRetries?: number;
}

/**
 * Embeddings adapter that delegates to the LangChain Ollama embeddings client.
 *
 * It validates the configuration up front, short-circuits empty inputs, and
 * wraps client failures in errors with a descriptive prefix.
 */
export class OllamaEmbeddings extends Embeddings {
  /** Underlying LangChain client that performs the actual embedding calls. */
  private readonly client: LangChainOllamaEmbeddings;

  /**
   * @param config Adapter configuration; `model` and `baseUrl` are mandatory.
   * @throws Error when `model` or `baseUrl` is missing/empty, or when the
   *   underlying client cannot be constructed.
   */
  constructor(config: OllamaEmbeddingsConfig) {
    super(config);

    // Validate required configuration
    if (!config.model) {
      throw new Error('Ollama embeddings model must be specified');
    }

    if (!config.baseUrl) {
      throw new Error('Ollama baseUrl must be specified');
    }

    try {
      // Create the LangChain OllamaEmbeddings client
      this.client = new LangChainOllamaEmbeddings({
        model: config.model,
        baseUrl: config.baseUrl,
        // NOTE(review): these request options are hard-coded (6 threads, mmap
        // enabled) and are not configurable through OllamaEmbeddingsConfig.
        requestOptions: {
          useMmap: true,
          numThread: 6,
        },
      });
    } catch (error: unknown) {
      throw new Error(
        `Failed to initialize Ollama embeddings client: ${OllamaEmbeddings.describeError(error)}`,
      );
    }
  }

  /**
   * Embed a list of documents.
   *
   * @param documents Texts to embed.
   * @returns One vector per document; an empty array when `documents` is
   *   empty or not provided (the client is not called in that case).
   * @throws Error when the underlying client fails.
   */
  async embedDocuments(documents: string[]): Promise<number[][]> {
    if (!documents || documents.length === 0) {
      return [];
    }

    try {
      return await this.client.embedDocuments(documents);
    } catch (error: unknown) {
      const detail = OllamaEmbeddings.describeError(error);
      console.error(`Ollama embeddings error: ${detail}`);
      throw new Error(`Failed to generate embeddings for documents: ${detail}`);
    }
  }

  /**
   * Embed a single query string.
   *
   * @param query Text to embed; must contain non-whitespace characters.
   * @returns The embedding vector for the query.
   * @throws Error when the query is empty/blank or the underlying client fails.
   */
  async embedQuery(query: string): Promise<number[]> {
    if (!query || query.trim() === '') {
      throw new Error('Query text cannot be empty');
    }

    try {
      return await this.client.embedQuery(query);
    } catch (error: unknown) {
      const detail = OllamaEmbeddings.describeError(error);
      console.error(`Ollama query embedding error: ${detail}`);
      throw new Error(`Failed to generate embedding for query: ${detail}`);
    }
  }

  /**
   * Turn an arbitrary thrown value into a printable message. Caught values are
   * not guaranteed to be `Error` instances, so `.message` cannot be read
   * unconditionally.
   */
  private static describeError(error: unknown): string {
    return error instanceof Error ? error.message : String(error);
  }
}
2 changes: 2 additions & 0 deletions packages/providers/src/reranker/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import { SearchResult, RerankerModelConfig } from '@refly/openapi-schema';
export interface RerankerConfig extends RerankerModelConfig {
/** API credentials */
apiKey: string;
/** Base URL for API endpoint */
baseUrl?: string;
}

/**
Expand Down
8 changes: 8 additions & 0 deletions packages/providers/src/reranker/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { RerankerModelConfig } from '@refly/openapi-schema';
import { BaseReranker } from './base';
import { JinaReranker } from './jina';
import { OllamaReranker } from './ollama';
import { BaseProvider } from '../types';

export const getReranker = (provider: BaseProvider, config: RerankerModelConfig): BaseReranker => {
Expand All @@ -10,6 +11,12 @@ export const getReranker = (provider: BaseProvider, config: RerankerModelConfig)
...config,
apiKey: provider.apiKey,
});
case 'ollama':
return new OllamaReranker({
...config,
apiKey: provider.apiKey,
baseUrl: provider.baseUrl,
});
default:
throw new Error(`Unsupported reranker provider: ${provider.providerKey}`);
}
Expand All @@ -18,3 +25,4 @@ export const getReranker = (provider: BaseProvider, config: RerankerModelConfig)
export * from './base';
export * from './fallback';
export * from './jina';
export * from './ollama';
121 changes: 121 additions & 0 deletions packages/providers/src/reranker/ollama.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import { SearchResult } from '@refly/openapi-schema';
import { BaseReranker } from './base';

/**
 * Shape of the JSON body returned by the Ollama generate endpoint when called
 * with `stream: false`. Only `response` is read by the reranker below; the
 * remaining fields are declared for completeness.
 */
interface OllamaCompletionResponse {
  /** Generated text; the reranker extracts the relevance score from it. */
  response: string;
  model: string;
  created_at: string;
  done: boolean;

  // Optional metadata fields; none of them are read in this file.
  context?: number[];
  total_duration?: number;
  load_duration?: number;
  prompt_eval_count?: number;
  prompt_eval_duration?: number;
  eval_count?: number;
  eval_duration?: number;
}

/**
 * Ollama-specific implementation of the reranker.
 *
 * Instead of a dedicated rerank endpoint, each candidate document is sent to
 * the Ollama generate endpoint together with the query, and the model is asked
 * to reply with a relevance score between 0.0 and 1.0.
 */
export class OllamaReranker extends BaseReranker {
  /** System prompt instructing the model to answer with a bare relevance score. */
  private static readonly SYSTEM_PROMPT = [
    'You are an expert at determining the relevance of a document to a query.',
    'Rate the relevance of the following document to the query on a scale from 0.0 to 1.0,',
    'where 0.0 is completely irrelevant and 1.0 is perfectly relevant.',
    'Return only the numeric score without any explanation.',
  ].join('\n');

  /**
   * Rerank search results using Ollama LLM.
   *
   * All documents are scored concurrently. If any request fails (network
   * error, non-2xx status, malformed body), the error is logged and the base
   * class fallback is returned instead of a partial ranking.
   *
   * @param query The user query
   * @param results The search results to rerank
   * @param options Per-call overrides for `topN`, `relevanceThreshold` and
   *   `modelId`; values from the reranker config are used when omitted
   * @returns Results with a `relevanceScore` at or above the threshold, sorted
   *   by descending score and truncated to `topN`
   */
  async rerank(
    query: string,
    results: SearchResult[],
    options?: {
      topN?: number;
      relevanceThreshold?: number;
      modelId?: string;
    },
  ): Promise<SearchResult[]> {
    const topN = options?.topN || this.config.topN || 5;
    // `??` rather than `||`: a threshold of 0 is a legitimate "keep everything"
    // setting and must not be silently replaced by the 0.5 default.
    const relevanceThreshold =
      options?.relevanceThreshold ?? this.config.relevanceThreshold ?? 0.5;
    const model = options?.modelId || this.config.modelId;

    if (results.length === 0) {
      return [];
    }

    // Extract texts from search results
    const documents = results.map((result) =>
      result.snippets.map((snippet) => snippet.text).join('\n\n'),
    );

    try {
      const headers: Record<string, string> = {
        'Content-Type': 'application/json',
      };

      if (this.config.apiKey) {
        headers.Authorization = `Bearer ${this.config.apiKey}`;
      }

      // Strip trailing slashes so the endpoint never contains `//generate`.
      // NOTE(review): `/generate` is appended directly, so a configured baseUrl
      // must already end with the `/api` prefix. Confirm that the provider's
      // stored base URL is not an OpenAI-compatible `/v1` URL.
      const baseUrl = (this.config.baseUrl || 'http://localhost:11434/api').replace(/\/+$/, '');
      const apiEndpoint = `${baseUrl}/generate`;

      // Rate each document in parallel
      const ratedResults = await Promise.all(
        documents.map(async (document, index) => {
          const userPrompt = `Query: ${query}\n\nDocument: ${document}\n\nRelevance score:`;

          const response = await fetch(apiEndpoint, {
            method: 'POST',
            headers,
            body: JSON.stringify({
              model,
              prompt: userPrompt,
              system: OllamaReranker.SYSTEM_PROMPT,
              stream: false,
              format: 'json',
            }),
          });

          if (!response.ok) {
            throw new Error(`Ollama API returned ${response.status}: ${await response.text()}`);
          }

          const data = (await response.json()) as OllamaCompletionResponse;

          return {
            ...results[index],
            relevanceScore: OllamaReranker.parseScore(data.response),
          };
        }),
      );

      // Filter by relevance threshold and sort by score
      return ratedResults
        .filter((result) => (result.relevanceScore ?? 0) >= relevanceThreshold)
        .sort((a, b) => (b.relevanceScore ?? 0) - (a.relevanceScore ?? 0))
        .slice(0, topN);
    } catch (e: unknown) {
      // Caught values are not guaranteed to be Error instances.
      const detail = e instanceof Error ? (e.stack ?? e.message) : String(e);
      console.error(`Ollama reranker failed, fallback to default: ${detail}`);
      // Use the fallback from the base class
      return this.defaultFallback(results);
    }
  }

  /**
   * Extract a relevance score in [0, 1] from the model's raw reply.
   *
   * The request asks for `format: 'json'`, so the reply is first parsed as
   * JSON: a bare number is used directly, and for an object/array the first
   * finite numeric value is used. This avoids picking up digits that appear in
   * key names. If that yields nothing, the first unsigned decimal found in the
   * text is used, and 0 when there is none. The result is clamped to [0, 1].
   */
  private static parseScore(raw: string): number {
    const text = raw.trim();
    let score: number | undefined;

    try {
      const parsed: unknown = JSON.parse(text);
      if (typeof parsed === 'number') {
        score = parsed;
      } else if (parsed !== null && typeof parsed === 'object') {
        score = Object.values(parsed).find(
          (value): value is number => typeof value === 'number' && Number.isFinite(value),
        );
      }
    } catch {
      // Not valid JSON; handled by the text fallback below.
    }

    if (score === undefined || !Number.isFinite(score)) {
      const scoreMatch = text.match(/([0-9]*\.?[0-9]+)/);
      score = scoreMatch ? Number.parseFloat(scoreMatch[0]) : 0;
    }

    // Ensure the score is within the valid range
    return Math.max(0, Math.min(1, score));
  }
}
2 changes: 1 addition & 1 deletion packages/utils/src/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export const providerInfoList: ProviderInfo[] = [
{
key: 'ollama',
name: 'Ollama',
categories: ['llm', 'embedding'],
categories: ['llm', 'embedding', 'reranker'],
fieldConfig: {
apiKey: { presence: 'optional' },
baseUrl: {
Expand Down