diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 0000000..d1f6bdf --- /dev/null +++ b/.prettierrc @@ -0,0 +1,10 @@ +{ + "semi": true, + "trailingComma": "es5", + "singleQuote": true, + "printWidth": 100, + "tabWidth": 2, + "useTabs": false, + "bracketSpacing": true, + "arrowParens": "avoid" +} \ No newline at end of file diff --git a/__pycache__/analytics.cpython-311.pyc b/__pycache__/analytics.cpython-311.pyc new file mode 100644 index 0000000..c940d63 Binary files /dev/null and b/__pycache__/analytics.cpython-311.pyc differ diff --git a/__pycache__/analytics.cpython-312.pyc b/__pycache__/analytics.cpython-312.pyc new file mode 100644 index 0000000..99200dc Binary files /dev/null and b/__pycache__/analytics.cpython-312.pyc differ diff --git a/__pycache__/client.cpython-311.pyc b/__pycache__/client.cpython-311.pyc new file mode 100644 index 0000000..cd269d0 Binary files /dev/null and b/__pycache__/client.cpython-311.pyc differ diff --git a/__pycache__/client.cpython-312.pyc b/__pycache__/client.cpython-312.pyc new file mode 100644 index 0000000..d1345f5 Binary files /dev/null and b/__pycache__/client.cpython-312.pyc differ diff --git a/__pycache__/complexity.cpython-311.pyc b/__pycache__/complexity.cpython-311.pyc new file mode 100644 index 0000000..fc81ac6 Binary files /dev/null and b/__pycache__/complexity.cpython-311.pyc differ diff --git a/__pycache__/complexity.cpython-312.pyc b/__pycache__/complexity.cpython-312.pyc new file mode 100644 index 0000000..a790332 Binary files /dev/null and b/__pycache__/complexity.cpython-312.pyc differ diff --git a/__pycache__/examples.cpython-311.pyc b/__pycache__/examples.cpython-311.pyc new file mode 100644 index 0000000..5dc0268 Binary files /dev/null and b/__pycache__/examples.cpython-311.pyc differ diff --git a/__pycache__/examples.cpython-312.pyc b/__pycache__/examples.cpython-312.pyc new file mode 100644 index 0000000..0db9a1f Binary files /dev/null and b/__pycache__/examples.cpython-312.pyc differ diff --git a/__pycache__/format.cpython-311.pyc b/__pycache__/format.cpython-311.pyc new file mode 100644 index 0000000..43a69a5 Binary files /dev/null and b/__pycache__/format.cpython-311.pyc differ diff --git a/__pycache__/format.cpython-312.pyc b/__pycache__/format.cpython-312.pyc new file mode 100644 index 0000000..3786e4c Binary files /dev/null and b/__pycache__/format.cpython-312.pyc differ diff --git a/__pycache__/reasoning.cpython-311.pyc b/__pycache__/reasoning.cpython-311.pyc new file mode 100644 index 0000000..e027d95 Binary files /dev/null and b/__pycache__/reasoning.cpython-311.pyc differ diff --git a/__pycache__/reasoning.cpython-312.pyc b/__pycache__/reasoning.cpython-312.pyc new file mode 100644 index 0000000..d7cc14c Binary files /dev/null and b/__pycache__/reasoning.cpython-312.pyc differ diff --git a/cod_analytics.db b/cod_analytics.db new file mode 100644 index 0000000..7141249 Binary files /dev/null and b/cod_analytics.db differ diff --git a/cod_examples.db b/cod_examples.db new file mode 100644 index 0000000..9452efa Binary files /dev/null and b/cod_examples.db differ diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..790a5fc --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,21 @@ +version: "3.8" +services: + app: + build: . + ports: + - "3000:3000" + environment: + - COD_DB_URL=postgresql://postgres:postgres@db:5432/cod_db + depends_on: + - db + db: + image: postgres:15-alpine + environment: + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=postgres + - POSTGRES_DB=cod_db + volumes: + - postgres_data:/var/lib/postgresql/data + +volumes: + postgres_data: diff --git a/eslint.config.ts b/eslint.config.ts new file mode 100644 index 0000000..3f36b61 --- /dev/null +++ b/eslint.config.ts @@ -0,0 +1,68 @@ +import eslint from '@typescript-eslint/eslint-plugin'; +import tsParser from '@typescript-eslint/parser'; +import prettierPlugin from 'eslint-plugin-prettier'; +import prettier from 'eslint-config-prettier'; +import globals from 'globals'; + +const config = [ + { + // Global configuration + ignores: ['dist/**', 'node_modules/**', 'coverage/**'], + }, + { + // TypeScript files configuration + files: ['**/*.ts'], + languageOptions: { + parser: tsParser, + parserOptions: { + ecmaVersion: 'latest' as const, + sourceType: 'module' as const, + }, + globals: { + ...globals.node, + ...globals.es2021, + }, + }, + plugins: { + '@typescript-eslint': eslint, + prettier: prettierPlugin, + }, + rules: { + // Prettier + 'prettier/prettier': 'error', + + // TypeScript + '@typescript-eslint/explicit-function-return-type': 'off', + '@typescript-eslint/no-explicit-any': 'warn', + '@typescript-eslint/no-unused-vars': ['warn', { argsIgnorePattern: '^_' }], + + // General + 'no-console': ['warn', { allow: ['warn', 'error'] }], + + // Include recommended rules + ...eslint.configs.recommended.rules, + ...prettier.rules, + }, + }, + { + // JavaScript files configuration (for config files) + files: ['**/*.js', '**/*.cjs', '**/*.mjs'], + languageOptions: { + ecmaVersion: 'latest' as const, + sourceType: 'module' as const, + globals: { + ...globals.node, + ...globals.es2021, + }, + }, + plugins: { + prettier: prettierPlugin, + }, + rules: { + 'prettier/prettier': 'error', + ...prettier.rules, + }, + }, +] as const; + +export default config; diff --git a/jest.config.js b/jest.config.js new file mode 100644 index 0000000..343557e --- /dev/null +++ b/jest.config.js @@ -0,0 +1,8 @@ +module.exports = { + preset: 'ts-jest', + testEnvironment: 'node', + testMatch: ['**/__tests__/**/*.test.ts'], + moduleFileExtensions: ['ts', 'js'], + transform: { '^.+\.ts$': 'ts-jest' }, + globals: { 'ts-jest': { tsconfig: 'tsconfig.json' } }, +}; diff --git a/prisma/schema.prisma b/prisma/schema.prisma new file mode 100644 index 0000000..6a55845 --- /dev/null +++ b/prisma/schema.prisma @@ -0,0 +1 @@ +datasource db { provider = "postgresql" url = env("COD_DB_URL") } generator client { provider = "prisma-client-js" } model InferenceRecord { id Int @id @default(autoincrement()) timestamp DateTime @default(now()) problemId String problemText String domain String approach String wordLimit Int tokensUsed Int executionTimeMs Float reasoningSteps String answer String expectedAnswer String? isCorrect Int? metaData Json? } model Example { id Int @id @default(autoincrement()) problem String @db.Text reasoning String @db.Text answer String domain String approach String metaData Json? } diff --git a/src/typescript/__tests__/client.test.ts b/src/typescript/__tests__/client.test.ts new file mode 100644 index 0000000..e636e57 --- /dev/null +++ b/src/typescript/__tests__/client.test.ts @@ -0,0 +1,49 @@ +import { ChainOfDraftClient } from '../client'; +import { AnalyticsService } from '../analytics'; +import { ComplexityEstimator } from '../complexity'; +describe('ChainOfDraftClient', () => { + let client: ChainOfDraftClient; + beforeEach(() => { + client = new ChainOfDraftClient(); + }); + test('should solve a simple math problem', async () => { + const result = await client.completions('test-model', 'What is 2+2?', { domain: 'math' }); + expect(result.choices[0].text).toBeDefined(); + }); + test('should handle chat completions', async () => { + const result = await client.chat('test-model', [{ role: 'user', content: 'What is 2+2?' }], { + domain: 'math', + }); + expect(result.choices[0].message.content).toBeDefined(); + }); +}); +describe('ComplexityEstimator', () => { + let estimator: ComplexityEstimator; + beforeEach(() => { + estimator = new ComplexityEstimator(); + }); + test('should estimate problem complexity', async () => { + const complexity = await estimator.estimateComplexity('What is 2+2?', 'math'); + expect(complexity).toBeGreaterThanOrEqual(3); + expect(complexity).toBeLessThanOrEqual(10); + }); +}); +describe('AnalyticsService', () => { + let analytics: AnalyticsService; + beforeEach(() => { + analytics = new AnalyticsService('postgresql://test:test@localhost:5432/test_db'); + }); + test('should record inference', async () => { + const id = await analytics.recordInference( + 'What is 2+2?', + 'math', + 'CoD', + 5, + 100, + 500, + 'Step 1: Add numbers', + '4' + ); + expect(id).toBeDefined(); + }); +}); diff --git a/src/typescript/analytics.ts b/src/typescript/analytics.ts new file mode 100644 index 0000000..611cd14 --- /dev/null +++ b/src/typescript/analytics.ts @@ -0,0 +1,157 @@ +import { PrismaClient } from '@prisma/client'; +// interface InferenceRecord { +// id: number; +// timestamp: Date; +// problemId: string; +// problemText: string; +// domain: string; +// approach: string; +// wordLimit: number; +// tokensUsed: number; +// executionTimeMs: number; +// reasoningSteps: string; +// answer: string; +// expectedAnswer?: string; +// isCorrect?: number; +// metaData?: any; +// } +interface PerformanceStats { + domain: string; + approach: string; + avgTokens: number; + avgTimeMs: number; + accuracy: number | null; + count: number; +} +interface TokenReductionStats { + domain: string; + codAvgTokens: number; + cotAvgTokens: number; + reductionPercentage: number; +} +interface AccuracyComparison { + domain: string; + codAccuracy: number | null; + cotAccuracy: number | null; + accuracyDifference: number | null; +} +export class AnalyticsService { + private prisma: PrismaClient; + constructor(dbUrl?: string) { + this.prisma = new PrismaClient({ + datasources: { + db: { url: dbUrl || process.env.COD_DB_URL || 'postgresql://localhost:5432/cod_analytics' }, + }, + }); + } + async recordInference( + problem: string, + domain: string, + approach: string, + wordLimit: number, + tokensUsed: number, + executionTime: number, + reasoning: string, + answer: string, + expectedAnswer?: string, + metadata?: any + ): Promise { + const problemId = Math.abs( + problem.split('').reduce((acc, char) => { + return (acc * 31 + char.charCodeAt(0)) >>> 0; + }, 0) % + 10 ** 10 + ).toString(); + const record = await this.prisma.inferenceRecord.create({ + data: { + problemId, + problemText: problem, + domain, + approach, + wordLimit, + tokensUsed, + executionTimeMs: executionTime, + reasoningSteps: reasoning, + answer, + expectedAnswer, + isCorrect: expectedAnswer ? this.checkCorrectness(answer, expectedAnswer) : null, + metaData: metadata, + }, + }); + return record.id; + } + private checkCorrectness(answer: string, expectedAnswer: string): number | null { + if (!answer || !expectedAnswer) { + return null; + } + return answer.trim().toLowerCase() === expectedAnswer.trim().toLowerCase() ? 1 : 0; + } + async getPerformanceByDomain(domain?: string): Promise { + const records = await this.prisma.inferenceRecord.groupBy({ + by: ['domain', 'approach'], + _avg: { tokensUsed: true, executionTimeMs: true, isCorrect: true }, + _count: { id: true }, + where: domain ? { domain } : undefined, + }); + return records.map(r => ({ + domain: r.domain, + approach: r.approach, + avgTokens: r._avg.tokensUsed || 0, + avgTimeMs: r._avg.executionTimeMs || 0, + accuracy: r._avg.isCorrect, + count: r._count.id, + })); + } + async getTokenReductionStats(): Promise { + const domains = await this.prisma.inferenceRecord.findMany({ + distinct: ['domain'], + select: { domain: true }, + }); + const results: TokenReductionStats[] = []; + for (const { domain } of domains) { + const [codAvg, cotAvg] = await Promise.all([ + this.prisma.inferenceRecord.aggregate({ + where: { domain, approach: 'CoD' }, + _avg: { tokensUsed: true }, + }), + this.prisma.inferenceRecord.aggregate({ + where: { domain, approach: 'CoT' }, + _avg: { tokensUsed: true }, + }), + ]); + const codAvgTokens = codAvg._avg.tokensUsed || 0; + const cotAvgTokens = cotAvg._avg.tokensUsed || 0; + const reductionPercentage = cotAvgTokens > 0 ? (1 - codAvgTokens / cotAvgTokens) * 100 : 0; + results.push({ domain, codAvgTokens, cotAvgTokens, reductionPercentage }); + } + return results; + } + async getAccuracyComparison(): Promise { + const domains = await this.prisma.inferenceRecord.findMany({ + distinct: ['domain'], + select: { domain: true }, + }); + const results: AccuracyComparison[] = []; + for (const { domain } of domains) { + const [codAccuracy, cotAccuracy] = await Promise.all([ + this.prisma.inferenceRecord.aggregate({ + where: { domain, approach: 'CoD', isCorrect: { not: null } }, + _avg: { isCorrect: true }, + }), + this.prisma.inferenceRecord.aggregate({ + where: { domain, approach: 'CoT', isCorrect: { not: null } }, + _avg: { isCorrect: true }, + }), + ]); + const codAcc = codAccuracy._avg.isCorrect; + const cotAcc = cotAccuracy._avg.isCorrect; + results.push({ + domain, + codAccuracy: codAcc, + cotAccuracy: cotAcc, + accuracyDifference: codAcc !== null && cotAcc !== null ? codAcc - cotAcc : null, + }); + } + return results; + } +} diff --git a/src/typescript/client.ts b/src/typescript/client.ts new file mode 100644 index 0000000..8b54d7f --- /dev/null +++ b/src/typescript/client.ts @@ -0,0 +1,294 @@ +import { Anthropic } from '@anthropic-ai/sdk'; +import { OpenAI } from 'openai'; +import MistralClient from '@mistralai/mistralai'; +import dotenv from 'dotenv'; +import { AnalyticsService } from './analytics'; +import { ComplexityEstimator } from './complexity'; +import { ExampleDatabase } from './examples'; +import { FormatEnforcer } from './format'; +import { ReasoningSelector } from './reasoning'; +import { createCodPrompt, createCotPrompt } from './prompts'; +import { logger } from '../utils/logger'; + +interface ChatMessage { + role: 'user' | 'assistant' | 'system'; + content: string; +} + +dotenv.config(); +interface ClientSettings { + maxWordsPerStep: number; + enforceFormat: boolean; + adaptiveWordLimit: boolean; + trackAnalytics: boolean; + maxTokens: number; + model?: string; +} +interface ChatResponse { + content: string; + usage: { inputTokens: number; outputTokens: number }; +} +class UnifiedLLMClient { + private provider: string; + private model: string; + private client: any; + constructor() { + this.provider = process.env.LLM_PROVIDER?.toLowerCase() || 'anthropic'; + this.model = process.env.LLM_MODEL || 'claude-3-sonnet-20240229'; + if (this.provider === 'anthropic') { + if (!process.env.ANTHROPIC_API_KEY) { + throw new Error('ANTHROPIC_API_KEY is required for Anthropic provider'); + } + this.client = new Anthropic({ + apiKey: process.env.ANTHROPIC_API_KEY, + baseURL: process.env.ANTHROPIC_BASE_URL, + }); + } else if (this.provider === 'openai') { + if (!process.env.OPENAI_API_KEY) { + throw new Error('OPENAI_API_KEY is required for OpenAI provider'); + } + this.client = new OpenAI({ + apiKey: process.env.OPENAI_API_KEY, + baseURL: process.env.OPENAI_BASE_URL, + }); + } else if (this.provider === 'mistral') { + if (!process.env.MISTRAL_API_KEY) { + throw new Error('MISTRAL_API_KEY is required for Mistral provider'); + } + this.client = new MistralClient(process.env.MISTRAL_API_KEY); + } else if (this.provider === 'ollama') { + this.client = { baseURL: process.env.OLLAMA_BASE_URL || 'http://localhost:11434' }; + } else { + throw new Error(`Unsupported LLM provider: ${this.provider}`); + } + } + async getAvailableModels(): Promise { + try { + if (this.provider === 'anthropic') { + const response = await this.client.models.list(); + return response.data.map((model: any) => model.id); + } else if (this.provider === 'openai') { + const response = await this.client.models.list(); + return response.data.map((model: any) => model.id); + } else if (this.provider === 'mistral') { + const response = await this.client.listModels(); + return response.data?.map((model: any) => model.id) || []; + } else if (this.provider === 'ollama') { + const response = await fetch(`${this.client.baseURL}/api/tags`); + const data = await response.json(); + return data.models.map((model: any) => model.name); + } + return []; + } catch (e) { + logger.error('Error fetching models:', e); + return []; + } + } + async chat( + messages: ChatMessage[], + model?: string, + maxTokens?: number, + temperature?: number + ): Promise { + model = model || this.model; + try { + if (this.provider === 'anthropic') { + const response = await this.client.messages.create({ + model, + messages: messages.map(msg => ({ role: msg.role, content: msg.content })), + max_tokens: maxTokens, + temperature, + }); + return { + content: response.content[0].text, + usage: { + inputTokens: response.usage.input_tokens, + outputTokens: response.usage.output_tokens, + }, + }; + } else if (this.provider === 'openai') { + const response = await this.client.chat.completions.create({ + model, + messages, + max_tokens: maxTokens, + temperature, + }); + return { + content: response.choices[0].message.content, + usage: { + inputTokens: response.usage.prompt_tokens, + outputTokens: response.usage.completion_tokens, + }, + }; + } else if (this.provider === 'mistral') { + const response = await this.client.chat({ + model, + messages: messages.map(msg => ({ role: msg.role, content: msg.content })), + max_tokens: maxTokens, + temperature, + }); + return { + content: response.choices[0].message.content, + usage: { + inputTokens: response.usage.prompt_tokens, + outputTokens: response.usage.completion_tokens, + }, + }; + } else if (this.provider === 'ollama') { + const response = await fetch(`${this.client.baseURL}/api/chat`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + model, + messages, + options: { num_predict: maxTokens, temperature }, + }), + }); + const data = await response.json(); + return { content: data.message.content, usage: { inputTokens: 0, outputTokens: 0 } }; + } + throw new Error(`Unsupported provider: ${this.provider}`); + } catch (e) { + logger.error('Error in chat:', e); + throw e; + } + } +} +export class ChainOfDraftClient { + private llmClient: UnifiedLLMClient; + private analytics: AnalyticsService; + private complexityEstimator: ComplexityEstimator; + private exampleDb: ExampleDatabase; + private formatEnforcer: FormatEnforcer; + private reasoningSelector: ReasoningSelector; + private settings: ClientSettings; + constructor(apiKey?: string, baseUrl?: string, settings: Partial = {}) { + this.llmClient = new UnifiedLLMClient(); + this.analytics = new AnalyticsService(); + this.complexityEstimator = new ComplexityEstimator(); + this.exampleDb = new ExampleDatabase(); + this.formatEnforcer = new FormatEnforcer(); + this.reasoningSelector = new ReasoningSelector(this.analytics); + this.settings = { + maxWordsPerStep: 8, + enforceFormat: true, + adaptiveWordLimit: true, + trackAnalytics: true, + maxTokens: 200000, + ...settings, + }; + } + async completions(model?: string, prompt?: string, options: any = {}): Promise { + if (!prompt) { + throw new Error('Prompt is required'); + } + const problem = prompt; + const domain = options.domain || 'general'; + const result = await this.solveWithReasoning(problem, domain, { + model: model || this.settings.model, + ...options, + }); + return { + id: `cod-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`, + object: 'completion', + created: Date.now(), + model: model || this.settings.model, + choices: [{ text: result.answer, index: 0, logprobs: null, finish_reason: 'stop' }], + usage: result.usage, + }; + } + async chat(model?: string, messages?: ChatMessage[], options: any = {}): Promise { + if (!messages || !messages.length) { + throw new Error('Messages are required'); + } + const lastMessage = messages[messages.length - 1]; + const problem = lastMessage.content; + const domain = options.domain || 'general'; + const result = await this.solveWithReasoning(problem, domain, { + model: model || this.settings.model, + ...options, + }); + + // Format messages for Mistral without using Message class + const formattedMessages = messages.map(msg => ({ + role: msg.role, + content: msg.content, + })); + + return { + id: `cod-${Date.now()}-${Math.random().toString(36).substr(2, 9)}`, + object: 'chat.completion', + created: Date.now(), + model: model || this.settings.model, + choices: [{ message: { role: 'assistant', content: result.answer }, finish_reason: 'stop' }], + usage: result.usage, + }; + } + async solveWithReasoning( + problem: string, + domain: string = 'general', + options: any = {} + ): Promise { + const startTime = Date.now(); + const [approach, reason] = await this.reasoningSelector.selectApproach(problem, domain); + const wordLimit = this.settings.adaptiveWordLimit + ? await this.complexityEstimator.estimateComplexity(problem, domain) + : this.settings.maxWordsPerStep; + const examples = await this.exampleDb.getExamples(domain, approach); + const prompt = + approach === 'CoD' + ? createCodPrompt(problem, domain, wordLimit, examples) + : createCotPrompt(problem, domain, examples); + + const messages: ChatMessage[] = [ + { role: 'system', content: prompt.system }, + { role: 'user', content: prompt.user } + ]; + + const response = await this.llmClient.chat( + messages, + options.model, + options.maxTokens, + options.temperature + ); + + let answer = response.content; + if (this.settings.enforceFormat && approach === 'CoD') { + answer = this.formatEnforcer.enforceWordLimit(answer, wordLimit); + } + const executionTime = Date.now() - startTime; + if (this.settings.trackAnalytics) { + await this.analytics.recordInference( + problem, + domain, + approach, + wordLimit, + response.usage.inputTokens + response.usage.outputTokens, + executionTime, + answer.split('####')[0].trim(), + answer.split('####')[1]?.trim() || answer, + options.expectedAnswer, + { reason, ...options.metadata } + ); + } + return { + answer: answer.split('####')[1]?.trim() || answer, + reasoning: answer.split('####')[0].trim(), + approach, + wordLimit, + usage: response.usage, + }; + } + async getPerformanceStats(domain?: string): Promise { + return this.analytics.getPerformanceByDomain(domain); + } + async getTokenReductionStats(): Promise { + return this.analytics.getTokenReductionStats(); + } + updateSettings(settings: Partial): void { + this.settings = { ...this.settings, ...settings }; + } + public async getAvailableModels(): Promise { + return this.llmClient.getAvailableModels(); + } +} diff --git a/src/typescript/complexity.ts b/src/typescript/complexity.ts new file mode 100644 index 0000000..f3c8ef9 --- /dev/null +++ b/src/typescript/complexity.ts @@ -0,0 +1,213 @@ +interface DomainBaseLimits { + [key: string]: number; +} +interface ComplexityIndicators { + [key: string]: string[]; +} +interface ProblemAnalysis { + domain: string; + baseLimit: number; + wordCount: number; + lengthFactor: number; + indicatorCount: number; + foundIndicators: string[]; + indicatorFactor: number; + questionCount: number; + questionFactor: number; + sentenceCount: number; + wordsPerSentence: number; + sentenceComplexityFactor: number; + estimatedComplexity: number; +} +export class ComplexityEstimator { + private domainBaseLimits: DomainBaseLimits = { + math: 6, + logic: 5, + common_sense: 4, + physics: 7, + chemistry: 6, + biology: 5, + code: 8, + puzzle: 5, + general: 5, + }; + private complexityIndicators: ComplexityIndicators = { + math: [ + 'integral', + 'derivative', + 'equation', + 'proof', + 'theorem', + 'calculus', + 'matrix', + 'vector', + 'linear algebra', + 'probability', + 'statistics', + 'geometric series', + 'differential', + 'polynomial', + 'factorial', + ], + logic: [ + 'if and only if', + 'necessary condition', + 'sufficient', + 'contradiction', + 'syllogism', + 'premise', + 'fallacy', + 'converse', + 'counterexample', + 'logical equivalence', + 'negation', + 'disjunction', + 'conjunction', + ], + code: [ + 'recursion', + 'algorithm', + 'complexity', + 'optimization', + 'function', + 'class', + 'object', + 'inheritance', + 'polymorphism', + 'data structure', + 'binary tree', + 'hash table', + 'graph', + 'dynamic programming', + ], + physics: [ + 'quantum', + 'relativity', + 'momentum', + 'force', + 'acceleration', + 'energy', + 'thermodynamics', + 'electric field', + 'magnetic field', + 'potential', + 'entropy', + 'wavelength', + 'frequency', + ], + chemistry: [ + 'reaction', + 'molecule', + 'compound', + 'element', + 'equilibrium', + 'acid', + 'base', + 'oxidation', + 'reduction', + 'catalyst', + 'isomer', + ], + biology: [ + 'gene', + 'protein', + 'enzyme', + 'cell', + 'tissue', + 'organ', + 'system', + 'metabolism', + 'photosynthesis', + 'respiration', + 'homeostasis', + ], + puzzle: [ + 'constraint', + 'sequence', + 'pattern', + 'rules', + 'probability', + 'combination', + 'permutation', + 'optimal', + 'strategy', + ], + }; + async estimateComplexity(problem: string, domain: string = 'general'): Promise { + const baseLimit = this.domainBaseLimits[domain.toLowerCase()] || 5; + const lengthFactor = Math.min(problem.split(' ').length / 50, 2); + let indicatorCount = 0; + const indicators = this.complexityIndicators[domain.toLowerCase()] || []; + for (const indicator of indicators) { + if (problem.toLowerCase().includes(indicator.toLowerCase())) { + indicatorCount++; + } + } + const indicatorFactor = Math.min(1 + indicatorCount * 0.2, 1.8); + const questionFactor = 1 + (problem.split('?').length - 1) * 0.2; + const sentences = problem.split('.').filter(s => s.trim()); + const wordsPerSentence = problem.split(' ').length / Math.max(sentences.length, 1); + const sentenceComplexityFactor = Math.min(wordsPerSentence / 15, 1.5); + let domainFactor = 1.0; + if ( + domain.toLowerCase() === 'math' && + ['prove', 'proof', 'theorem'].some(term => problem.toLowerCase().includes(term)) + ) { + domainFactor = 1.3; + } else if ( + domain.toLowerCase() === 'code' && + ['implement', 'function', 'algorithm'].some(term => problem.toLowerCase().includes(term)) + ) { + domainFactor = 1.2; + } + const impactFactor = Math.max( + lengthFactor, + indicatorFactor, + questionFactor, + sentenceComplexityFactor, + domainFactor + ); + const adjustedLimit = Math.round(baseLimit * impactFactor); + return Math.max(3, Math.min(adjustedLimit, 10)); + } + analyzeProblem(problem: string, domain: string = 'general'): ProblemAnalysis { + const baseLimit = this.domainBaseLimits[domain.toLowerCase()] || 5; + const wordCount = problem.split(' ').length; + const lengthFactor = Math.min(wordCount / 50, 2); + const indicators = this.complexityIndicators[domain.toLowerCase()] || []; + const foundIndicators = indicators.filter(ind => + problem.toLowerCase().includes(ind.toLowerCase()) + ); + const indicatorCount = foundIndicators.length; + const indicatorFactor = Math.min(1 + indicatorCount * 0.2, 1.8); + const questionCount = problem.split('?').length - 1; + const questionFactor = 1 + questionCount * 0.2; + const sentences = problem.split('.').filter(s => s.trim()); + const wordsPerSentence = wordCount / Math.max(sentences.length, 1); + const sentenceComplexityFactor = Math.min(wordsPerSentence / 15, 1.5); + return { + domain, + baseLimit, + wordCount, + lengthFactor, + indicatorCount, + foundIndicators, + indicatorFactor, + questionCount, + questionFactor, + sentenceCount: sentences.length, + wordsPerSentence, + sentenceComplexityFactor, + estimatedComplexity: Math.max( + 3, + Math.min( + Math.round( + baseLimit * + Math.max(lengthFactor, indicatorFactor, questionFactor, sentenceComplexityFactor) + ), + 10 + ) + ), + }; + } +} diff --git a/src/typescript/examples.ts b/src/typescript/examples.ts new file mode 100644 index 0000000..6bb93bb --- /dev/null +++ b/src/typescript/examples.ts @@ -0,0 +1,147 @@ +import { PrismaClient } from '@prisma/client'; +interface Example { + id: number; + problem: string; + reasoning: string; + answer: string; + domain: string; + approach: string; + metaData?: any; +} +interface ExampleCount { + domain: string; + approach: string; + count: number; +} +export class ExampleDatabase { + private prisma: PrismaClient; + constructor(dbPath?: string) { + this.prisma = new PrismaClient({ + datasources: { + db: { + url: dbPath || process.env.COD_EXAMPLES_DB || 'postgresql://localhost:5432/cod_examples', + }, + }, + }); + this.ensureExamplesExist(); + } + private async ensureExamplesExist(): Promise { + const count = await this.prisma.example.count(); + if (count === 0) { + await this.loadInitialExamples(); + } + } + private async loadInitialExamples(): Promise { + const examples = [ + { + problem: + 'Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?', + reasoning: + "Let'\s think through this step by step:\n1. Initially, Jason had 20 lollipops.\n2. After giving some to Denny, Jason now has 12 lollipops.\n3. To find out how many lollipops Jason gave to Denny, we need to calculate the difference between the initial number of lollipops and the remaining number.\n4. We can set up a simple subtraction problem: Initial number of lollipops - Remaining number of lollipops = Lollipops given to Denny\n5. Putting in the numbers: 20 - 12 = Lollipops given to Denny\n6. Solving the subtraction: 20 - 12 = 8", + answer: '8 lollipops', + domain: 'math', + approach: 'CoT', + }, + { + problem: + 'Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?', + reasoning: 'Initial: 20 lollipops\nRemaining: 12 lollipops\nGave away: 20-12=8 lollipops', + answer: '8 lollipops', + domain: 'math', + approach: 'CoD', + }, + { + problem: + 'A coin is heads up. John flips the coin. Mary flips the coin. Paul flips the coin. Susan does not flip the coin. Is the coin still heads up?', + reasoning: + "Let'\s track the state of the coin through each flip:\n1. Initially, the coin is heads up.\n2. John flips the coin, so it changes from heads to tails.\n3. Mary flips the coin, so it changes from tails to heads.\n4. Paul flips the coin, so it changes from heads to tails.\n5. Susan does not flip the coin, so it remains tails.\nTherefore, the coin is tails up, which means it is not still heads up.", + answer: 'No', + domain: 'logic', + approach: 'CoT', + }, + { + problem: + 'A coin is heads up. John flips the coin. Mary flips the coin. Paul flips the coin. Susan does not flip the coin. Is the coin still heads up?', + reasoning: 'H→J flips→T\nT→M flips→H\nH→P flips→T\nT→S no flip→T\nFinal: tails', + answer: 'No', + domain: 'logic', + approach: 'CoD', + }, + { + problem: + 'A car accelerates from 0 to 60 mph in 5 seconds. What is its acceleration in mph/s?', + reasoning: + "Let'\s solve this problem step by step:\n1. We know the initial velocity is 0 mph.\n2. The final velocity is 60 mph.\n3. The time taken is 5 seconds.\n4. Acceleration is the rate of change of velocity with respect to time.\n5. Using the formula: acceleration = (final velocity - initial velocity) / time\n6. Substituting the values: acceleration = (60 mph - 0 mph) / 5 seconds\n7. Simplifying: acceleration = 60 mph / 5 seconds = 12 mph/s", + answer: '12 mph/s', + domain: 'physics', + approach: 'CoT', + }, + { + problem: + 'A car accelerates from 0 to 60 mph in 5 seconds. What is its acceleration in mph/s?', + reasoning: 'a = Δv/Δt\na = (60-0)/5\na = 12 mph/s', + answer: '12 mph/s', + domain: 'physics', + approach: 'CoD', + }, + ]; + await this.prisma.example.createMany({ data: examples }); + } + async getExamples( + domain: string, + approach: string = 'CoD', + limit: number = 3 + ): Promise { + const examples = await this.prisma.example.findMany({ + where: { domain, approach }, + take: limit, + }); + return examples.map(ex => ({ + id: ex.id, + problem: ex.problem, + reasoning: ex.reasoning, + answer: ex.answer, + domain: ex.domain, + approach: ex.approach, + })); + } + async addExample( + problem: string, + reasoning: string, + answer: string, + domain: string, + approach: string = 'CoD', + metadata?: any + ): Promise { + const example = await this.prisma.example.create({ + data: { problem, reasoning, answer, domain, approach, metaData: metadata }, + }); + return example.id; + } + async transformCotToCod(cotExample: Example, maxWordsPerStep: number = 5): Promise { + const steps = this.extractReasoningSteps(cotExample.reasoning); + const codSteps = steps.map(step => this.summarizeStep(step, maxWordsPerStep)); + return { ...cotExample, reasoning: codSteps.join('\n'), approach: 'CoD' }; + } + private extractReasoningSteps(reasoning: string): string[] { + if (/\d+\./.test(reasoning)) { + return reasoning + .split(/\d+\./) + .filter(s => s.trim()) + .map(s => s.trim()); + } else { + return reasoning.split('\n').filter(s => s.trim()); + } + } + private summarizeStep(step: string, maxWords: number): string { + const words = step.split(/\s+/); + return words.length <= maxWords ? step : words.slice(0, maxWords).join(' '); + } + async getExampleCountByDomain(): Promise { + const results = await this.prisma.example.groupBy({ + by: ['domain', 'approach'], + _count: { id: true }, + }); + return results.map(r => ({ domain: r.domain, approach: r.approach, count: r._count.id })); + } +} diff --git a/src/typescript/format.ts b/src/typescript/format.ts new file mode 100644 index 0000000..01926cb --- /dev/null +++ b/src/typescript/format.ts @@ -0,0 +1,96 @@ +interface AdherenceMetrics { + totalSteps: number; + stepsWithinLimit: number; + averageWordsPerStep: number; + maxWordsInAnyStep: number; + adherenceRate: number; + stepCounts: number[]; +} +export class FormatEnforcer { + private stepPattern: RegExp; + constructor() { + this.stepPattern = new RegExp( + '(\d+\.\s*|Step\s+\d+:|\n-\s+|\n\*\s+|•\s+|^\s*-\s+|^\s*\*\s+)', + 'm' + ); + } + enforceWordLimit(reasoning: string, maxWordsPerStep: number): string { + const steps = this.splitIntoSteps(reasoning); + const enforcedSteps = steps.map(step => this.enforceStep(step, maxWordsPerStep)); + return enforcedSteps.join('\n'); + } + private splitIntoSteps(reasoning: string): string[] { + if (this.stepPattern.test(reasoning)) { + const parts: string[] = []; + let currentPart = ''; + const lines = reasoning.split('\n'); + for (const line of lines) { + if (this.stepPattern.test(line) || /^\s*\d+\./.test(line)) { + if (currentPart) { + parts.push(currentPart); + } + currentPart = line; + } else { + if (currentPart) { + currentPart += '\n' + line; + } else { + currentPart = line; + } + } + } + if (currentPart) { + parts.push(currentPart); + } + return parts.length ? parts : [reasoning]; + } else { + return reasoning + .split('\n') + .map(line => line.trim()) + .filter(Boolean); + } + } + private enforceStep(step: string, maxWords: number): string { + const words = step.split(/\s+/); + if (words.length <= maxWords) { + return step; + } + const match = step.match(/^(\d+\.\s*|Step\s+\d+:|\s*-\s+|\s*\*\s+|•\s+)/); + const marker = match ? match[0] : ''; + const content = marker ? step.slice(marker.length).trim() : step; + const contentWords = content.split(/\s+/); + const truncated = contentWords.slice(0, maxWords).join(' '); + return `${marker}${truncated}`; + } + analyzeAdherence(reasoning: string, maxWordsPerStep: number): AdherenceMetrics { + const steps = this.splitIntoSteps(reasoning); + const stepCounts = steps.map(step => { + const match = step.match(/^(\d+\.\s*|Step\s+\d+:|\s*-\s+|\s*\*\s+|•\s+)/); + const marker = match ? match[0] : ''; + const content = marker ? step.slice(marker.length).trim() : step; + return content.split(/\s+/).length; + }); + const totalSteps = steps.length; + const stepsWithinLimit = stepCounts.filter(count => count <= maxWordsPerStep).length; + const averageWordsPerStep = totalSteps ? stepCounts.reduce((a, b) => a + b, 0) / totalSteps : 0; + const maxWordsInAnyStep = stepCounts.length ? Math.max(...stepCounts) : 0; + const adherenceRate = totalSteps ? stepsWithinLimit / totalSteps : 1.0; + return { + totalSteps, + stepsWithinLimit, + averageWordsPerStep, + maxWordsInAnyStep, + adherenceRate, + stepCounts, + }; + } + formatToNumberedSteps(reasoning: string): string { + const steps = this.splitIntoSteps(reasoning); + return steps + .map((step, i) => { + const match = step.match(/^(\d+\.\s*|Step\s+\d+:|\s*-\s+|\s*\*\s+|•\s+)/); + const content = match ? step.slice(match[0].length).trim() : step.trim(); + return `${i + 1}. ${content}`; + }) + .join('\n'); + } +} diff --git a/src/typescript/index.ts b/src/typescript/index.ts new file mode 100644 index 0000000..d5469a3 --- /dev/null +++ b/src/typescript/index.ts @@ -0,0 +1,14 @@ +/** + * Chain of Draft (CoD) TypeScript Implementation + * Main entry point for the TypeScript port + */ + +export * from './analytics'; +export * from './client'; +export * from './complexity'; +export * from './examples'; +export * from './format'; +export * from './reasoning'; + +// Re-export server for direct usage +export { startServer } from '../server'; diff --git a/src/typescript/prompts.ts b/src/typescript/prompts.ts new file mode 100644 index 0000000..5a88b7c --- /dev/null +++ b/src/typescript/prompts.ts @@ -0,0 +1,65 @@ +interface Example { + problem: string; + reasoning: string; + answer: string; + domain: string; + approach: string; +} + +interface Prompt { + system: string; + user: string; +} + +export function createCodPrompt( + problem: string, + domain: string, + wordLimit: number, + examples: Example[] = [] +): Prompt { + const systemPrompt = `You are a Chain of Draft (CoD) reasoning assistant. Break down problems into clear, concise steps. +Each step should be no more than ${wordLimit} words. +Domain: ${domain}`; + + const userPrompt = `Problem: ${problem} + +${examples.length > 0 ? formatExamples(examples) : ''} +Please solve this step by step, keeping each step under ${wordLimit} words. +End with a clear final answer after "####".`; + + return { + system: systemPrompt, + user: userPrompt, + }; +} + +export function createCotPrompt( + problem: string, + domain: string, + examples: Example[] = [] +): Prompt { + const systemPrompt = `You are a Chain of Thought (CoT) reasoning assistant. Break down problems into clear steps. +Domain: ${domain}`; + + const userPrompt = `Problem: ${problem} + +${examples.length > 0 ? formatExamples(examples) : ''} +Let's solve this step by step. +End with a clear final answer after "####".`; + + return { + system: systemPrompt, + user: userPrompt, + }; +} + +function formatExamples(examples: Example[]): string { + return examples.map(ex => + `Example: +Problem: ${ex.problem} +Reasoning: +${ex.reasoning} +Answer: ${ex.answer} +---` + ).join('\n\n'); +} \ No newline at end of file diff --git a/src/typescript/reasoning.ts b/src/typescript/reasoning.ts new file mode 100644 index 0000000..8aa5481 --- /dev/null +++ b/src/typescript/reasoning.ts @@ -0,0 +1,121 @@ +import { ComplexityEstimator } from './complexity'; +import { AnalyticsService } from './analytics'; +interface DomainPreferences { + complexityThreshold: number; + accuracyThreshold: number; +} +interface Preferences { + [key: string]: DomainPreferences; +} +export class ReasoningSelector { + private analytics: AnalyticsService; + private defaultPreferences: Preferences = { + math: { complexityThreshold: 7, accuracyThreshold: 0.85 }, + code: { complexityThreshold: 8, accuracyThreshold: 0.9 }, + physics: { complexityThreshold: 7, accuracyThreshold: 0.85 }, + chemistry: { complexityThreshold: 7, accuracyThreshold: 0.85 }, + biology: { complexityThreshold: 6, accuracyThreshold: 0.85 }, + logic: { complexityThreshold: 6, accuracyThreshold: 0.9 }, + puzzle: { complexityThreshold: 7, accuracyThreshold: 0.85 }, + default: { complexityThreshold: 6, accuracyThreshold: 0.8 }, + }; + constructor(analyticsService: AnalyticsService) { + this.analytics = analyticsService; + } + async selectApproach( + problem: string, + domain: string, + complexityScore?: number + ): Promise<[string, string]> { + const prefs = this.defaultPreferences[domain.toLowerCase()] || this.defaultPreferences.default; + if (complexityScore === undefined) { + const estimator = new ComplexityEstimator(); + complexityScore = await estimator.estimateComplexity(problem, domain); + } + if (complexityScore > prefs.complexityThreshold) { + return [ + 'CoT', + `Problem complexity (${complexityScore}) exceeds threshold (${prefs.complexityThreshold})`, + ]; + } + const domainPerformance = await this.analytics.getPerformanceByDomain(domain); + const codAccuracy = domainPerformance.find(p => p.approach === 'CoD')?.accuracy || null; + if (codAccuracy !== null && codAccuracy < prefs.accuracyThreshold) { + return [ + 'CoT', + `Historical accuracy with CoD (${codAccuracy.toFixed(2)}) below threshold (${prefs.accuracyThreshold})`, + ]; + } + return ['CoD', 'Default to Chain-of-Draft for efficiency']; + } + updatePreferences( + domain: string, + complexityThreshold?: number, + accuracyThreshold?: number + ): void { + if (!(domain in this.defaultPreferences)) { + this.defaultPreferences[domain] = { ...this.defaultPreferences.default }; + } + if (complexityThreshold !== undefined) { + this.defaultPreferences[domain].complexityThreshold = complexityThreshold; + } + if (accuracyThreshold !== undefined) { + this.defaultPreferences[domain].accuracyThreshold = accuracyThreshold; + } + } + getPreferences(domain?: string): DomainPreferences | Preferences { + return domain + ? this.defaultPreferences[domain] || this.defaultPreferences.default + : this.defaultPreferences; + } +} +interface Example { + problem: string; + reasoning: string; + answer: string; +} +export function createCodPrompt( + problem: string, + domain: string, + maxWordsPerStep: number, + examples?: Example[] +): { system: string; user: string } { + let systemPrompt = `You are an expert problem solver using Chain of Draft reasoning. Think step by step, but only keep a minimum draft for each thinking step, with ${maxWordsPerStep} words at most per step. Return the answer at the end after "####".`; + if (domain.toLowerCase() === 'math') { + systemPrompt += '\nUse mathematical notation to keep steps concise.'; + } else if (domain.toLowerCase() === 'code') { + systemPrompt += '\nUse pseudocode or short code snippets when appropriate.'; + } else if (domain.toLowerCase() === 'physics') { + systemPrompt += '\nUse equations and physical quantities with units.'; + } + let exampleText = ''; + if (examples) { + for (const example of examples) { + exampleText += `\nProblem: ${example.problem}\nSolution:\n${example.reasoning}\n####\n${example.answer}\n`; + } + } + const userPrompt = `Problem: ${problem}`; + return { system: systemPrompt, user: exampleText ? `${exampleText}\n${userPrompt}` : userPrompt }; +} +export function createCotPrompt( + problem: string, + domain: string, + examples?: Example[] +): { system: string; user: string } { + let systemPrompt = `Think step by step to answer the following question. Return the answer at the end of the response after a separator ####.`; + if (domain.toLowerCase() === 'math') { + systemPrompt += '\nMake sure to show all mathematical operations clearly.'; + } else if (domain.toLowerCase() === 'code') { + systemPrompt += '\nBe detailed about algorithms and implementation steps.'; + } else if (domain.toLowerCase() === 'physics') { + systemPrompt += '\nExplain physical principles and equations in detail.'; + } + let exampleText = ''; + if (examples) { + for (const example of examples) { + exampleText += `\nProblem: ${example.problem}\nSolution:\n${example.reasoning}\n####\n${example.answer}\n`; + } + } + const userPrompt = `Problem: ${problem}`; + return { system: systemPrompt, user: exampleText ? `${exampleText}\n${userPrompt}` : userPrompt }; +} diff --git a/src/typescript/server.ts b/src/typescript/server.ts new file mode 100644 index 0000000..8a2427a --- /dev/null +++ b/src/typescript/server.ts @@ -0,0 +1,79 @@ +import express from 'express'; +import { ChainOfDraftClient } from './client'; +import { AnalyticsService } from './analytics'; +import dotenv from 'dotenv'; +import { logger } from '../utils/logger'; + +dotenv.config(); +const app = express(); +const port = process.env.PORT || 3000; +app.use(express.json()); +const client = new ChainOfDraftClient(); +const analytics = new AnalyticsService(); + +app.post('/v1/completions', async (req, res) => { + try { + const { model, prompt, ...options } = req.body; + const result = await client.completions(model, prompt, options); + res.json(result); + } catch (error) { + logger.error('Error in completions:', error); + res.status(500).json({ error: error instanceof Error ? error.message : 'Unknown error' }); + } +}); + +app.post('/v1/chat/completions', async (req, res) => { + try { + const { model, messages, ...options } = req.body; + const result = await client.chat(model, messages, options); + res.json(result); + } catch (error) { + logger.error('Error in chat:', error); + res.status(500).json({ error: error instanceof Error ? error.message : 'Unknown error' }); + } +}); + +app.get('/v1/models', async (req, res) => { + try { + const models = await client.getAvailableModels(); + res.json({ data: models.map(id => ({ id })) }); + } catch (error) { + logger.error('Error getting models:', error); + res.status(500).json({ error: error instanceof Error ? error.message : 'Unknown error' }); + } +}); + +app.get('/v1/analytics/performance', async (req, res) => { + try { + const { domain } = req.query; + const stats = await analytics.getPerformanceByDomain(domain as string); + res.json({ data: stats }); + } catch (error) { + logger.error('Error getting performance stats:', error); + res.status(500).json({ error: error instanceof Error ? error.message : 'Unknown error' }); + } +}); + +app.get('/v1/analytics/token-reduction', async (req, res) => { + try { + const stats = await analytics.getTokenReductionStats(); + res.json({ data: stats }); + } catch (error) { + logger.error('Error getting token reduction stats:', error); + res.status(500).json({ error: error instanceof Error ? error.message : 'Unknown error' }); + } +}); + +app.get('/v1/analytics/accuracy', async (req, res) => { + try { + const stats = await analytics.getAccuracyComparison(); + res.json({ data: stats }); + } catch (error) { + logger.error('Error getting accuracy stats:', error); + res.status(500).json({ error: error instanceof Error ? error.message : 'Unknown error' }); + } +}); + +app.listen(port, () => { + logger.success(`Server running at http://localhost:${port}`); +}); diff --git a/write_files.sh b/write_files.sh new file mode 100755 index 0000000..a9bf588 --- /dev/null +++ b/write_files.sh @@ -0,0 +1 @@ +#!/bin/bash