mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-04-22 18:13:52 +00:00
feat(processor): add parallel processing and NVIDIA API support
- Implement parallel chunk processing with configurable concurrency
- Add direct NVIDIA API integration bypassing LangChain for better control
- Optimize for DGX Spark unified memory with batch processing
- Use concurrency of 4 for Ollama, 2 for other providers
- Add proper error handling and user stop capability
- Update NVIDIA model to Llama 3.3 Nemotron Super 49B v1.5
- Improve prompt engineering for triple extraction
This commit is contained in:
parent
12c4777eae
commit
23b5cbca4c
@ -138,11 +138,9 @@ export class TextProcessor {
|
||||
|
||||
case 'nvidia':
|
||||
try {
|
||||
// Use the default Nemotron model for NVIDIA
|
||||
this.llm = await langChainService.getNemotronModel({
|
||||
temperature: 0.1,
|
||||
maxTokens: 8192
|
||||
});
|
||||
// For NVIDIA, we'll use direct OpenAI client instead of LangChain
|
||||
// This is handled in processText method
|
||||
this.llm = null; // Set to null, will be handled differently
|
||||
} catch (error) {
|
||||
console.error('Failed to initialize NVIDIA model:', error);
|
||||
throw new Error(`Failed to initialize NVIDIA model: ${error instanceof Error ? error.message : String(error)}`);
|
||||
@ -210,11 +208,16 @@ export class TextProcessor {
|
||||
await this.initialize();
|
||||
}
|
||||
|
||||
// Ensure we have an LLM to extract triples
|
||||
// For NVIDIA, use direct OpenAI client
|
||||
if (this.selectedLLMProvider === 'nvidia') {
|
||||
return await this.processTextWithNvidiaAPI(text);
|
||||
}
|
||||
|
||||
// Ensure we have an LLM to extract triples for non-NVIDIA providers
|
||||
if (!this.llm) {
|
||||
const providerMessage = this.selectedLLMProvider === 'ollama'
|
||||
? "Ollama server connection failed. Please ensure Ollama is running and accessible."
|
||||
: "NVIDIA API key is required. Please set NVIDIA_API_KEY in your environment variables.";
|
||||
: "LLM configuration error";
|
||||
throw new Error(`LLM configuration error: ${providerMessage}`);
|
||||
}
|
||||
|
||||
@ -222,14 +225,100 @@ export class TextProcessor {
|
||||
const chunks = await this.chunkText(text);
|
||||
console.log(`Split text into ${chunks.length} chunks`);
|
||||
|
||||
// Step 2: Process each chunk to extract triples
|
||||
// Step 2: Process chunks in parallel with controlled concurrency
|
||||
// DGX Spark has unified memory, so we can prefetch batches into GPU before processing
|
||||
const concurrency = this.selectedLLMProvider === 'ollama' ? 4 : 2; // Higher concurrency for local Ollama
|
||||
const allTriples: Array<Triple & { confidence: number, metadata: any }> = [];
|
||||
|
||||
console.log(`Processing with concurrency: ${concurrency} (provider: ${this.selectedLLMProvider})`);
|
||||
|
||||
// Helper function to process a single chunk
|
||||
/**
 * Extracts triples from a single chunk of text.
 *
 * Checks the user-stop flag first; when it is set, the flag is reset and an
 * error is thrown. Otherwise the extraction prompt is formatted for the chunk,
 * the LLM is invoked, and the reply is handed to the triple parser.
 *
 * @param chunk Text of the chunk to process
 * @param index Zero-based position of the chunk in `chunks` (only used in log messages)
 * @returns The triples the parser produced from the LLM response for this chunk
 * @throws Error('Processing stopped by user') when the stop flag is set
 * @throws Error when the LLM response content is neither a string nor an array of parts
 */
const processChunk = async (chunk: string, index: number) => {
  // Check if processing should be stopped.
  // NOTE(review): the flag is reset here, before any await. Chunks started later
  // in the same parallel batch will presumably no longer see the stop request
  // and will still call the LLM — confirm whether the reset should happen once
  // in the caller instead.
  if (getShouldStopProcessing()) {
    console.log(`Processing stopped by user at chunk ${index + 1}/${chunks.length}`);
    resetStopProcessing();
    throw new Error('Processing stopped by user');
  }

  console.log(`Processing chunk ${index + 1}/${chunks.length} (${chunk.length} chars)`);

  // Format the prompt with the chunk and parser instructions
  const formatInstructions = this.tripleParser!.getFormatInstructions();
  const prompt = await this.extractionTemplate!.format({
    text: chunk,
    format_instructions: formatInstructions
  });

  // Extract triples using the LLM
  const response = await this.llm!.invoke(prompt);

  // Narrow the content at runtime instead of asserting `as string`: chat models
  // may return an array of content parts, which must not reach the parser as-is.
  const partToText = (part: unknown): string => {
    if (typeof part === 'string') {
      return part;
    }
    const text = (part as { text?: unknown } | null)?.text;
    return typeof text === 'string' ? text : '';
  };

  const content: unknown = response.content;
  let responseText: string;
  if (typeof content === 'string') {
    responseText = content;
  } else if (Array.isArray(content)) {
    responseText = content.map(partToText).join('');
  } else {
    throw new Error(`Unexpected LLM response content type: ${typeof content}`);
  }

  const parsedTriples = await this.tripleParser!.parse(responseText);

  return parsedTriples;
};
|
||||
|
||||
// Process chunks in batches with controlled concurrency
|
||||
for (let i = 0; i < chunks.length; i += concurrency) {
|
||||
const batch = chunks.slice(i, i + concurrency);
|
||||
const batchIndices = Array.from({ length: batch.length }, (_, idx) => i + idx);
|
||||
|
||||
console.log(`Processing batch ${Math.floor(i / concurrency) + 1}/${Math.ceil(chunks.length / concurrency)} (${batch.length} chunks in parallel)`);
|
||||
|
||||
try {
|
||||
// Process batch in parallel - GPU can prefetch next chunks while processing current ones
|
||||
const results = await Promise.all(
|
||||
batch.map((chunk, idx) => processChunk(chunk, batchIndices[idx]))
|
||||
);
|
||||
|
||||
// Flatten and add to all triples
|
||||
results.forEach((triples: Array<Triple & { confidence: number, metadata: any }>) => {
|
||||
allTriples.push(...triples);
|
||||
});
|
||||
} catch (error) {
|
||||
console.error(`Error processing batch:`, error);
|
||||
// Re-throw a user-requested stop; any other error is only logged, so the failed batch is skipped and processing continues with the next batch
|
||||
if (error instanceof Error && error.message === 'Processing stopped by user') {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Post-process to remove duplicates and normalize
|
||||
const processedTriples = this.postProcessTriples(allTriples);
|
||||
console.log(`Extracted ${processedTriples.length} unique triples after post-processing`);
|
||||
|
||||
return processedTriples;
|
||||
}
|
||||
|
||||
/**
|
||||
* Process text using NVIDIA API directly with OpenAI client (bypasses LangChain)
|
||||
* @param text Text to process
|
||||
* @returns Array of triples with metadata
|
||||
*/
|
||||
private async processTextWithNvidiaAPI(text: string): Promise<Array<Triple & { confidence: number, metadata: any }>> {
|
||||
const apiKey = process.env.NVIDIA_API_KEY;
|
||||
if (!apiKey) {
|
||||
throw new Error('NVIDIA_API_KEY is required but not set');
|
||||
}
|
||||
|
||||
// Initialize parser if needed
|
||||
if (!this.tripleParser) {
|
||||
await this.initialize();
|
||||
}
|
||||
|
||||
// Step 1: Chunk the text
|
||||
const chunks = await this.chunkText(text);
|
||||
console.log(`Split text into ${chunks.length} chunks`);
|
||||
|
||||
// Step 2: Process each chunk
|
||||
const allTriples: Array<Triple & { confidence: number, metadata: any }> = [];
|
||||
|
||||
for (let i = 0; i < chunks.length; i++) {
|
||||
// Check if processing should be stopped
|
||||
if (getShouldStopProcessing()) {
|
||||
console.log(`Processing stopped by user at chunk ${i + 1}/${chunks.length}`);
|
||||
resetStopProcessing(); // Reset the flag for next time
|
||||
resetStopProcessing();
|
||||
throw new Error('Processing stopped by user');
|
||||
}
|
||||
|
||||
@ -237,25 +326,62 @@ export class TextProcessor {
|
||||
console.log(`Processing chunk ${i + 1}/${chunks.length} (${chunk.length} chars)`);
|
||||
|
||||
try {
|
||||
// Format the prompt with the chunk and parser instructions
|
||||
// Create the prompt
|
||||
const formatInstructions = this.tripleParser!.getFormatInstructions();
|
||||
const prompt = await this.extractionTemplate!.format({
|
||||
text: chunk,
|
||||
format_instructions: formatInstructions
|
||||
const prompt = `You are a knowledge graph builder that extracts structured information from text.
|
||||
Extract subject-predicate-object triples from the following text.
|
||||
|
||||
Guidelines:
|
||||
- Extract only factual triples present in the text
|
||||
- Normalize entity names to their canonical form
|
||||
- Assign appropriate confidence scores (0-1)
|
||||
- Include entity types in metadata
|
||||
- For each triple, include a brief context from the source text
|
||||
|
||||
Text: ${chunk}
|
||||
|
||||
${formatInstructions}`;
|
||||
|
||||
// Call NVIDIA API directly using fetch
|
||||
const response = await fetch('https://integrate.api.nvidia.com/v1/chat/completions', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': `Bearer ${apiKey}`
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model: 'nvidia/llama-3.3-nemotron-super-49b-v1.5',
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content: prompt
|
||||
}
|
||||
],
|
||||
temperature: 0.1,
|
||||
max_tokens: 8192,
|
||||
top_p: 0.95
|
||||
})
|
||||
});
|
||||
|
||||
// Extract triples using the LLM
|
||||
const response = await this.llm!.invoke(prompt);
|
||||
const responseText = response.content as string;
|
||||
const parsedTriples = await this.tripleParser!.parse(responseText);
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`NVIDIA API error (${response.status}): ${errorText}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
const responseText = data.choices[0].message.content;
|
||||
|
||||
// Parse the response
|
||||
const parsedTriples = await this.tripleParser!.parse(responseText);
|
||||
allTriples.push(...parsedTriples);
|
||||
|
||||
} catch (error) {
|
||||
console.error(`Error processing chunk ${i + 1}:`, error);
|
||||
throw error; // Re-throw to see the actual error
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Post-process to remove duplicates and normalize
|
||||
// Step 3: Post-process
|
||||
const processedTriples = this.postProcessTriples(allTriples);
|
||||
console.log(`Extracted ${processedTriples.length} unique triples after post-processing`);
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user