Skip to content

Commit 665afe9

Browse files
authored
Merge pull request #948 from refly-ai/feat/support-ollama-embedding
Feat/support ollama embedding
2 parents 6071c70 + 42a885d commit 665afe9

File tree

6 files changed

+210
-1
lines changed

6 files changed

+210
-1
lines changed

packages/providers/src/embeddings/index.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { Embeddings } from '@langchain/core/embeddings';
33
import { OpenAIEmbeddings } from '@langchain/openai';
44
import { FireworksEmbeddings } from '@langchain/community/embeddings/fireworks';
55
import { JinaEmbeddings } from './jina';
6+
import { OllamaEmbeddings } from './ollama';
67
import { BaseProvider } from '../types';
78

89
export const getEmbeddings = (provider: BaseProvider, config: EmbeddingModelConfig): Embeddings => {
@@ -29,6 +30,15 @@ export const getEmbeddings = (provider: BaseProvider, config: EmbeddingModelConf
2930
apiKey: provider.apiKey,
3031
maxRetries: 3,
3132
});
33+
case 'ollama':
34+
return new OllamaEmbeddings({
35+
model: config.modelId,
36+
batchSize: config.batchSize,
37+
dimensions: config.dimensions,
38+
baseUrl: provider.baseUrl || 'http://localhost:11434/v1',
39+
apiKey: provider.apiKey,
40+
maxRetries: 3,
41+
});
3242
default:
3343
throw new Error(`Unsupported embeddings provider: ${provider.providerKey}`);
3444
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import { Embeddings } from '@langchain/core/embeddings';
2+
import { OllamaEmbeddings as LangChainOllamaEmbeddings } from '@langchain/ollama';
3+
4+
export interface OllamaEmbeddingsConfig {
5+
model: string;
6+
batchSize?: number;
7+
maxRetries?: number;
8+
dimensions?: number;
9+
baseUrl: string;
10+
apiKey?: string;
11+
}
12+
13+
export class OllamaEmbeddings extends Embeddings {
14+
private client: LangChainOllamaEmbeddings;
15+
16+
constructor(config: OllamaEmbeddingsConfig) {
17+
super(config);
18+
19+
// Validate required configuration
20+
if (!config.model) {
21+
throw new Error('Ollama embeddings model must be specified');
22+
}
23+
24+
if (!config.baseUrl) {
25+
throw new Error('Ollama baseUrl must be specified');
26+
}
27+
28+
try {
29+
// Create the LangChain OllamaEmbeddings client
30+
this.client = new LangChainOllamaEmbeddings({
31+
model: config.model,
32+
baseUrl: config.baseUrl,
33+
requestOptions: {
34+
useMmap: true,
35+
numThread: 6,
36+
},
37+
});
38+
} catch (error) {
39+
throw new Error(`Failed to initialize Ollama embeddings client: ${error.message}`);
40+
}
41+
}
42+
43+
async embedDocuments(documents: string[]): Promise<number[][]> {
44+
if (!documents || documents.length === 0) {
45+
return [];
46+
}
47+
48+
try {
49+
return await this.client.embedDocuments(documents);
50+
} catch (error) {
51+
console.error(`Ollama embeddings error: ${error.message}`);
52+
throw new Error(`Failed to generate embeddings for documents: ${error.message}`);
53+
}
54+
}
55+
56+
async embedQuery(query: string): Promise<number[]> {
57+
if (!query || query.trim() === '') {
58+
throw new Error('Query text cannot be empty');
59+
}
60+
61+
try {
62+
return await this.client.embedQuery(query);
63+
} catch (error) {
64+
console.error(`Ollama query embedding error: ${error.message}`);
65+
throw new Error(`Failed to generate embedding for query: ${error.message}`);
66+
}
67+
}
68+
}

packages/providers/src/reranker/base.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import { SearchResult, RerankerModelConfig } from '@refly/openapi-schema';
66
export interface RerankerConfig extends RerankerModelConfig {
77
/** API credentials */
88
apiKey: string;
9+
/** Base URL for API endpoint */
10+
baseUrl?: string;
911
}
1012

1113
/**

packages/providers/src/reranker/index.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { RerankerModelConfig } from '@refly/openapi-schema';
22
import { BaseReranker } from './base';
33
import { JinaReranker } from './jina';
4+
import { OllamaReranker } from './ollama';
45
import { BaseProvider } from '../types';
56

67
export const getReranker = (provider: BaseProvider, config: RerankerModelConfig): BaseReranker => {
@@ -10,6 +11,12 @@ export const getReranker = (provider: BaseProvider, config: RerankerModelConfig)
1011
...config,
1112
apiKey: provider.apiKey,
1213
});
14+
case 'ollama':
15+
return new OllamaReranker({
16+
...config,
17+
apiKey: provider.apiKey,
18+
baseUrl: provider.baseUrl,
19+
});
1320
default:
1421
throw new Error(`Unsupported reranker provider: ${provider.providerKey}`);
1522
}
@@ -18,3 +25,4 @@ export const getReranker = (provider: BaseProvider, config: RerankerModelConfig)
1825
export * from './base';
1926
export * from './fallback';
2027
export * from './jina';
28+
export * from './ollama';
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import { SearchResult } from '@refly/openapi-schema';
2+
import { BaseReranker } from './base';
3+
4+
interface OllamaCompletionResponse {
5+
model: string;
6+
created_at: string;
7+
response: string;
8+
done: boolean;
9+
context?: number[];
10+
total_duration?: number;
11+
load_duration?: number;
12+
prompt_eval_count?: number;
13+
prompt_eval_duration?: number;
14+
eval_count?: number;
15+
eval_duration?: number;
16+
}
17+
18+
/**
19+
* Ollama-specific implementation of the reranker
20+
* Uses the LLM directly for reranking via completion API
21+
*/
22+
export class OllamaReranker extends BaseReranker {
23+
/**
24+
* Rerank search results using Ollama LLM.
25+
*
26+
* @param query The user query
27+
* @param results The search results to rerank
28+
* @param options Additional options for reranking
29+
* @returns Reranked search results
30+
*/
31+
async rerank(
32+
query: string,
33+
results: SearchResult[],
34+
options?: {
35+
topN?: number;
36+
relevanceThreshold?: number;
37+
modelId?: string;
38+
},
39+
): Promise<SearchResult[]> {
40+
const topN = options?.topN || this.config.topN || 5;
41+
const relevanceThreshold = options?.relevanceThreshold || this.config.relevanceThreshold || 0.5;
42+
const model = options?.modelId || this.config.modelId;
43+
44+
if (results.length === 0) {
45+
return [];
46+
}
47+
48+
// Extract texts from search results
49+
const documents = results.map((result) =>
50+
result.snippets.map((snippet) => snippet.text).join('\n\n'),
51+
);
52+
53+
// For each document, ask the LLM to rate its relevance to the query
54+
try {
55+
const headers: Record<string, string> = {
56+
'Content-Type': 'application/json',
57+
};
58+
59+
if (this.config.apiKey) {
60+
headers.Authorization = `Bearer ${this.config.apiKey}`;
61+
}
62+
63+
// Prepare the base URL
64+
const baseUrl = (this.config.baseUrl || 'http://localhost:11434/api').replace(/\/+$/, '');
65+
const apiEndpoint = `${baseUrl}/generate`;
66+
67+
// Rate each document in parallel
68+
const ratedResults = await Promise.all(
69+
documents.map(async (document, index) => {
70+
const systemPrompt = `You are an expert at determining the relevance of a document to a query.
71+
Rate the relevance of the following document to the query on a scale from 0.0 to 1.0,
72+
where 0.0 is completely irrelevant and 1.0 is perfectly relevant.
73+
Return only the numeric score without any explanation.`;
74+
75+
const userPrompt = `Query: ${query}\n\nDocument: ${document}\n\nRelevance score:`;
76+
77+
const response = await fetch(apiEndpoint, {
78+
method: 'post',
79+
headers,
80+
body: JSON.stringify({
81+
model,
82+
prompt: userPrompt,
83+
system: systemPrompt,
84+
stream: false,
85+
format: 'json',
86+
}),
87+
});
88+
89+
if (!response.ok) {
90+
throw new Error(`Ollama API returned ${response.status}: ${await response.text()}`);
91+
}
92+
93+
const data: OllamaCompletionResponse = await response.json();
94+
95+
// Extract the score from the response
96+
// Clean up the response to ensure we get a valid number
97+
const scoreMatch = data.response.trim().match(/([0-9]*\.?[0-9]+)/);
98+
const score = scoreMatch ? Number.parseFloat(scoreMatch[0]) : 0;
99+
100+
// Ensure the score is within the valid range
101+
const normalizedScore = Math.max(0, Math.min(1, score));
102+
103+
return {
104+
...results[index],
105+
relevanceScore: normalizedScore,
106+
};
107+
}),
108+
);
109+
110+
// Filter by relevance threshold and sort by score
111+
return ratedResults
112+
.filter((result) => (result.relevanceScore ?? 0) >= relevanceThreshold)
113+
.sort((a, b) => (b.relevanceScore ?? 0) - (a.relevanceScore ?? 0))
114+
.slice(0, topN);
115+
} catch (e) {
116+
console.error(`Ollama reranker failed, fallback to default: ${e.stack}`);
117+
// Use the fallback from the base class
118+
return this.defaultFallback(results);
119+
}
120+
}
121+
}

packages/utils/src/provider.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ export const providerInfoList: ProviderInfo[] = [
3030
{
3131
key: 'ollama',
3232
name: 'Ollama',
33-
categories: ['llm', 'embedding'],
33+
categories: ['llm', 'embedding', 'reranker'],
3434
fieldConfig: {
3535
apiKey: { presence: 'optional' },
3636
baseUrl: {

0 commit comments

Comments
 (0)