mirror of
https://github.com/NVIDIA/dgx-spark-playbooks.git
synced 2026-04-24 10:53:52 +00:00
feat(processor): add parallel processing and NVIDIA API support
- Implement parallel chunk processing with configurable concurrency - Add direct NVIDIA API integration bypassing LangChain for better control - Optimize for DGX Spark unified memory with batch processing - Use concurrency of 4 for Ollama, 2 for other providers - Add proper error handling and user stop capability - Update NVIDIA model to Llama 3.3 Nemotron Super 49B v1.5 - Improve prompt engineering for triple extraction
This commit is contained in:
parent
12c4777eae
commit
23b5cbca4c
@ -138,11 +138,9 @@ export class TextProcessor {
|
|||||||
|
|
||||||
case 'nvidia':
|
case 'nvidia':
|
||||||
try {
|
try {
|
||||||
// Use the default Nemotron model for NVIDIA
|
// For NVIDIA, we'll use direct OpenAI client instead of LangChain
|
||||||
this.llm = await langChainService.getNemotronModel({
|
// This is handled in processText method
|
||||||
temperature: 0.1,
|
this.llm = null; // Set to null, will be handled differently
|
||||||
maxTokens: 8192
|
|
||||||
});
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to initialize NVIDIA model:', error);
|
console.error('Failed to initialize NVIDIA model:', error);
|
||||||
throw new Error(`Failed to initialize NVIDIA model: ${error instanceof Error ? error.message : String(error)}`);
|
throw new Error(`Failed to initialize NVIDIA model: ${error instanceof Error ? error.message : String(error)}`);
|
||||||
@ -210,11 +208,16 @@ export class TextProcessor {
|
|||||||
await this.initialize();
|
await this.initialize();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure we have an LLM to extract triples
|
// For NVIDIA, use direct OpenAI client
|
||||||
|
if (this.selectedLLMProvider === 'nvidia') {
|
||||||
|
return await this.processTextWithNvidiaAPI(text);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure we have an LLM to extract triples for non-NVIDIA providers
|
||||||
if (!this.llm) {
|
if (!this.llm) {
|
||||||
const providerMessage = this.selectedLLMProvider === 'ollama'
|
const providerMessage = this.selectedLLMProvider === 'ollama'
|
||||||
? "Ollama server connection failed. Please ensure Ollama is running and accessible."
|
? "Ollama server connection failed. Please ensure Ollama is running and accessible."
|
||||||
: "NVIDIA API key is required. Please set NVIDIA_API_KEY in your environment variables.";
|
: "LLM configuration error";
|
||||||
throw new Error(`LLM configuration error: ${providerMessage}`);
|
throw new Error(`LLM configuration error: ${providerMessage}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -222,14 +225,100 @@ export class TextProcessor {
|
|||||||
const chunks = await this.chunkText(text);
|
const chunks = await this.chunkText(text);
|
||||||
console.log(`Split text into ${chunks.length} chunks`);
|
console.log(`Split text into ${chunks.length} chunks`);
|
||||||
|
|
||||||
// Step 2: Process each chunk to extract triples
|
// Step 2: Process chunks in parallel with controlled concurrency
|
||||||
|
// DGX Spark has unified memory, so we can prefetch batches into GPU before processing
|
||||||
|
const concurrency = this.selectedLLMProvider === 'ollama' ? 4 : 2; // Higher concurrency for local Ollama
|
||||||
|
const allTriples: Array<Triple & { confidence: number, metadata: any }> = [];
|
||||||
|
|
||||||
|
console.log(`Processing with concurrency: ${concurrency} (provider: ${this.selectedLLMProvider})`);
|
||||||
|
|
||||||
|
// Helper function to process a single chunk
|
||||||
|
const processChunk = async (chunk: string, index: number) => {
|
||||||
|
// Check if processing should be stopped
|
||||||
|
if (getShouldStopProcessing()) {
|
||||||
|
console.log(`Processing stopped by user at chunk ${index + 1}/${chunks.length}`);
|
||||||
|
resetStopProcessing();
|
||||||
|
throw new Error('Processing stopped by user');
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(`Processing chunk ${index + 1}/${chunks.length} (${chunk.length} chars)`);
|
||||||
|
|
||||||
|
// Format the prompt with the chunk and parser instructions
|
||||||
|
const formatInstructions = this.tripleParser!.getFormatInstructions();
|
||||||
|
const prompt = await this.extractionTemplate!.format({
|
||||||
|
text: chunk,
|
||||||
|
format_instructions: formatInstructions
|
||||||
|
});
|
||||||
|
|
||||||
|
// Extract triples using the LLM
|
||||||
|
const response = await this.llm!.invoke(prompt);
|
||||||
|
const responseText = response.content as string;
|
||||||
|
const parsedTriples = await this.tripleParser!.parse(responseText);
|
||||||
|
|
||||||
|
return parsedTriples;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Process chunks in batches with controlled concurrency
|
||||||
|
for (let i = 0; i < chunks.length; i += concurrency) {
|
||||||
|
const batch = chunks.slice(i, i + concurrency);
|
||||||
|
const batchIndices = Array.from({ length: batch.length }, (_, idx) => i + idx);
|
||||||
|
|
||||||
|
console.log(`Processing batch ${Math.floor(i / concurrency) + 1}/${Math.ceil(chunks.length / concurrency)} (${batch.length} chunks in parallel)`);
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Process batch in parallel - GPU can prefetch next chunks while processing current ones
|
||||||
|
const results = await Promise.all(
|
||||||
|
batch.map((chunk, idx) => processChunk(chunk, batchIndices[idx]))
|
||||||
|
);
|
||||||
|
|
||||||
|
// Flatten and add to all triples
|
||||||
|
results.forEach((triples: Array<Triple & { confidence: number, metadata: any }>) => {
|
||||||
|
allTriples.push(...triples);
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Error processing batch:`, error);
|
||||||
|
// Continue with next batch instead of failing completely
|
||||||
|
if (error instanceof Error && error.message === 'Processing stopped by user') {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Post-process to remove duplicates and normalize
|
||||||
|
const processedTriples = this.postProcessTriples(allTriples);
|
||||||
|
console.log(`Extracted ${processedTriples.length} unique triples after post-processing`);
|
||||||
|
|
||||||
|
return processedTriples;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Process text using NVIDIA API directly with OpenAI client (bypasses LangChain)
|
||||||
|
* @param text Text to process
|
||||||
|
* @returns Array of triples with metadata
|
||||||
|
*/
|
||||||
|
private async processTextWithNvidiaAPI(text: string): Promise<Array<Triple & { confidence: number, metadata: any }>> {
|
||||||
|
const apiKey = process.env.NVIDIA_API_KEY;
|
||||||
|
if (!apiKey) {
|
||||||
|
throw new Error('NVIDIA_API_KEY is required but not set');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize parser if needed
|
||||||
|
if (!this.tripleParser) {
|
||||||
|
await this.initialize();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 1: Chunk the text
|
||||||
|
const chunks = await this.chunkText(text);
|
||||||
|
console.log(`Split text into ${chunks.length} chunks`);
|
||||||
|
|
||||||
|
// Step 2: Process each chunk
|
||||||
const allTriples: Array<Triple & { confidence: number, metadata: any }> = [];
|
const allTriples: Array<Triple & { confidence: number, metadata: any }> = [];
|
||||||
|
|
||||||
for (let i = 0; i < chunks.length; i++) {
|
for (let i = 0; i < chunks.length; i++) {
|
||||||
// Check if processing should be stopped
|
// Check if processing should be stopped
|
||||||
if (getShouldStopProcessing()) {
|
if (getShouldStopProcessing()) {
|
||||||
console.log(`Processing stopped by user at chunk ${i + 1}/${chunks.length}`);
|
console.log(`Processing stopped by user at chunk ${i + 1}/${chunks.length}`);
|
||||||
resetStopProcessing(); // Reset the flag for next time
|
resetStopProcessing();
|
||||||
throw new Error('Processing stopped by user');
|
throw new Error('Processing stopped by user');
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -237,25 +326,62 @@ export class TextProcessor {
|
|||||||
console.log(`Processing chunk ${i + 1}/${chunks.length} (${chunk.length} chars)`);
|
console.log(`Processing chunk ${i + 1}/${chunks.length} (${chunk.length} chars)`);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Format the prompt with the chunk and parser instructions
|
// Create the prompt
|
||||||
const formatInstructions = this.tripleParser!.getFormatInstructions();
|
const formatInstructions = this.tripleParser!.getFormatInstructions();
|
||||||
const prompt = await this.extractionTemplate!.format({
|
const prompt = `You are a knowledge graph builder that extracts structured information from text.
|
||||||
text: chunk,
|
Extract subject-predicate-object triples from the following text.
|
||||||
format_instructions: formatInstructions
|
|
||||||
|
Guidelines:
|
||||||
|
- Extract only factual triples present in the text
|
||||||
|
- Normalize entity names to their canonical form
|
||||||
|
- Assign appropriate confidence scores (0-1)
|
||||||
|
- Include entity types in metadata
|
||||||
|
- For each triple, include a brief context from the source text
|
||||||
|
|
||||||
|
Text: ${chunk}
|
||||||
|
|
||||||
|
${formatInstructions}`;
|
||||||
|
|
||||||
|
// Call NVIDIA API directly using fetch
|
||||||
|
const response = await fetch('https://integrate.api.nvidia.com/v1/chat/completions', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'Authorization': `Bearer ${apiKey}`
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model: 'nvidia/llama-3.3-nemotron-super-49b-v1.5',
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: prompt
|
||||||
|
}
|
||||||
|
],
|
||||||
|
temperature: 0.1,
|
||||||
|
max_tokens: 8192,
|
||||||
|
top_p: 0.95
|
||||||
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
// Extract triples using the LLM
|
if (!response.ok) {
|
||||||
const response = await this.llm!.invoke(prompt);
|
const errorText = await response.text();
|
||||||
const responseText = response.content as string;
|
throw new Error(`NVIDIA API error (${response.status}): ${errorText}`);
|
||||||
const parsedTriples = await this.tripleParser!.parse(responseText);
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
const responseText = data.choices[0].message.content;
|
||||||
|
|
||||||
|
// Parse the response
|
||||||
|
const parsedTriples = await this.tripleParser!.parse(responseText);
|
||||||
allTriples.push(...parsedTriples);
|
allTriples.push(...parsedTriples);
|
||||||
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`Error processing chunk ${i + 1}:`, error);
|
console.error(`Error processing chunk ${i + 1}:`, error);
|
||||||
|
throw error; // Re-throw to see the actual error
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 3: Post-process to remove duplicates and normalize
|
// Step 3: Post-process
|
||||||
const processedTriples = this.postProcessTriples(allTriples);
|
const processedTriples = this.postProcessTriples(allTriples);
|
||||||
console.log(`Extracted ${processedTriples.length} unique triples after post-processing`);
|
console.log(`Extracted ${processedTriples.length} unique triples after post-processing`);
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user