Skip to content

Commit

Permalink
Merge pull request #462 from Hexastack/461-issue-saving-nlpsample-as-…
Browse files Browse the repository at this point in the history
…an-attachment

feat: import nlpsamples files without adding them as attachments
  • Loading branch information
marrouchi authored Dec 25, 2024
2 parents d143667 + 717d403 commit bb83cd5
Show file tree
Hide file tree
Showing 13 changed files with 503 additions and 246 deletions.
2 changes: 1 addition & 1 deletion api/src/config/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ export const config: Config = {
storageMode: 'disk',
maxUploadSize: process.env.UPLOAD_MAX_SIZE_IN_BYTES
? Number(process.env.UPLOAD_MAX_SIZE_IN_BYTES)
: 2000000,
: 50 * 1024 * 1024, // 50 MB in bytes
appName: 'Hexabot.ai',
},
pagination: {
Expand Down
100 changes: 25 additions & 75 deletions api/src/nlp/controllers/nlp-sample.controller.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,12 @@
* 2. All derivative works must include clear attribution to the original creator and software, Hexastack and Hexabot, in a prominent location (e.g., in the software's "About" section, documentation, and README file).
*/

import fs from 'fs';

import { CACHE_MANAGER } from '@nestjs/cache-manager';
import { BadRequestException, NotFoundException } from '@nestjs/common';
import { EventEmitter2 } from '@nestjs/event-emitter';
import { MongooseModule } from '@nestjs/mongoose';
import { Test, TestingModule } from '@nestjs/testing';

import { AttachmentRepository } from '@/attachment/repositories/attachment.repository';
import { AttachmentModel } from '@/attachment/schemas/attachment.schema';
import { AttachmentService } from '@/attachment/services/attachment.service';
import { HelperService } from '@/helper/helper.service';
import { LanguageRepository } from '@/i18n/repositories/language.repository';
import { Language, LanguageModel } from '@/i18n/schemas/language.schema';
Expand Down Expand Up @@ -50,7 +45,6 @@ import { NlpEntityService } from '../services/nlp-entity.service';
import { NlpSampleEntityService } from '../services/nlp-sample-entity.service';
import { NlpSampleService } from '../services/nlp-sample.service';
import { NlpValueService } from '../services/nlp-value.service';
import { NlpService } from '../services/nlp.service';

import { NlpSampleController } from './nlp-sample.controller';

Expand All @@ -60,7 +54,6 @@ describe('NlpSampleController', () => {
let nlpSampleService: NlpSampleService;
let nlpEntityService: NlpEntityService;
let nlpValueService: NlpValueService;
let attachmentService: AttachmentService;
let languageService: LanguageService;
let byeJhonSampleId: string;
let languages: Language[];
Expand All @@ -76,7 +69,6 @@ describe('NlpSampleController', () => {
MongooseModule.forFeature([
NlpSampleModel,
NlpSampleEntityModel,
AttachmentModel,
NlpEntityModel,
NlpValueModel,
SettingModel,
Expand All @@ -87,9 +79,7 @@ describe('NlpSampleController', () => {
LoggerService,
NlpSampleRepository,
NlpSampleEntityRepository,
AttachmentService,
NlpEntityService,
AttachmentRepository,
NlpEntityRepository,
NlpValueService,
NlpValueRepository,
Expand All @@ -98,7 +88,6 @@ describe('NlpSampleController', () => {
LanguageRepository,
LanguageService,
EventEmitter2,
NlpService,
HelperService,
SettingRepository,
SettingService,
Expand Down Expand Up @@ -131,7 +120,6 @@ describe('NlpSampleController', () => {
text: 'Bye Jhon',
})
).id;
attachmentService = module.get<AttachmentService>(AttachmentService);
languageService = module.get<LanguageService>(LanguageService);
languages = await languageService.findAll();
});
Expand Down Expand Up @@ -315,83 +303,44 @@ describe('NlpSampleController', () => {
});
});

describe('import', () => {
it('should throw exception when attachment is not found', async () => {
const invalidattachmentId = (
await attachmentService.findOne({
name: 'store2.jpg',
})
).id;
await attachmentService.deleteOne({ name: 'store2.jpg' });
await expect(
nlpSampleController.import(invalidattachmentId),
).rejects.toThrow(NotFoundException);
});

it('should throw exception when file location is not present', async () => {
const attachmentId = (
await attachmentService.findOne({
name: 'store1.jpg',
})
).id;
jest.spyOn(fs, 'existsSync').mockReturnValueOnce(false);
await expect(nlpSampleController.import(attachmentId)).rejects.toThrow(
NotFoundException,
describe('importFile', () => {
it('should throw exception when something is wrong with the upload', async () => {
const file = {
buffer: Buffer.from('', 'utf-8'),
size: 0,
mimetype: 'text/csv',
} as Express.Multer.File;
await expect(nlpSampleController.importFile(file)).rejects.toThrow(
'Bad Request Exception',
);
});

it('should return a failure if an error occurs when parsing csv file ', async () => {
const mockCsvDataWithErrors: string = `intent,entities,lang,question
greeting,person,en`;
jest.spyOn(fs, 'existsSync').mockReturnValueOnce(true);
jest.spyOn(fs, 'readFileSync').mockReturnValueOnce(mockCsvDataWithErrors);
const attachmentId = (
await attachmentService.findOne({
name: 'store1.jpg',
})
).id;

const mockParsedCsvDataWithErrors = {
data: [{ intent: 'greeting', entities: 'person', lang: 'en' }],
errors: [
{
type: 'FieldMismatch',
code: 'TooFewFields',
message: 'Too few fields: expected 4 fields but parsed 3',
row: 0,
},
],
meta: {
delimiter: ',',
linebreak: '\n',
aborted: false,
truncated: false,
cursor: 49,
fields: ['intent', 'entities', 'lang', 'question'],
},
};
await expect(nlpSampleController.import(attachmentId)).rejects.toThrow(
new BadRequestException({
cause: mockParsedCsvDataWithErrors.errors,
description: 'Error while parsing CSV',
}),
);
const buffer = Buffer.from(mockCsvDataWithErrors, 'utf-8');
const file = {
buffer,
size: buffer.length,
mimetype: 'text/csv',
} as Express.Multer.File;
await expect(nlpSampleController.importFile(file)).rejects.toThrow();
});

it('should import data from a CSV file', async () => {
const attachmentId = (
await attachmentService.findOne({
name: 'store1.jpg',
})
).id;
const mockCsvData: string = [
`text,intent,language`,
`How much does a BMW cost?,price,en`,
].join('\n');
jest.spyOn(fs, 'existsSync').mockReturnValueOnce(true);
jest.spyOn(fs, 'readFileSync').mockReturnValueOnce(mockCsvData);

const result = await nlpSampleController.import(attachmentId);
const buffer = Buffer.from(mockCsvData, 'utf-8');
const file = {
buffer,
size: buffer.length,
mimetype: 'text/csv',
} as Express.Multer.File;
const result = await nlpSampleController.importFile(file);
const intentEntityResult = await nlpEntityService.findOne({
name: 'intent',
});
Expand Down Expand Up @@ -429,9 +378,10 @@ describe('NlpSampleController', () => {
expect(intentEntityResult).toEqualPayload(intentEntity);
expect(priceValueResult).toEqualPayload(priceValue);
expect(textSampleResult).toEqualPayload(textSample);
expect(result).toEqual({ success: true });
expect(result).toEqualPayload([textSample]);
});
});

describe('deleteMany', () => {
it('should delete multiple nlp samples', async () => {
const samplesToDelete = [
Expand Down
142 changes: 9 additions & 133 deletions api/src/nlp/controllers/nlp-sample.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
* 2. All derivative works must include clear attribution to the original creator and software, Hexastack and Hexabot, in a prominent location (e.g., in the software's "About" section, documentation, and README file).
*/

import fs from 'fs';
import { join } from 'path';
import { Readable } from 'stream';

import {
Expand All @@ -25,14 +23,13 @@ import {
Query,
Res,
StreamableFile,
UploadedFile,
UseInterceptors,
} from '@nestjs/common';
import { FileInterceptor } from '@nestjs/platform-express';
import { CsrfCheck } from '@tekuconcept/nestjs-csrf';
import { Response } from 'express';
import Papa from 'papaparse';

import { AttachmentService } from '@/attachment/services/attachment.service';
import { config } from '@/config';
import { HelperService } from '@/helper/helper.service';
import { LanguageService } from '@/i18n/services/language.service';
import { CsrfInterceptor } from '@/interceptors/csrf.interceptor';
Expand All @@ -45,18 +42,17 @@ import { PopulatePipe } from '@/utils/pipes/populate.pipe';
import { SearchFilterPipe } from '@/utils/pipes/search-filter.pipe';
import { TFilterQuery } from '@/utils/types/filter.types';

import { NlpSampleCreateDto, NlpSampleDto } from '../dto/nlp-sample.dto';
import { NlpSampleDto } from '../dto/nlp-sample.dto';
import {
NlpSample,
NlpSampleFull,
NlpSamplePopulate,
NlpSampleStub,
} from '../schemas/nlp-sample.schema';
import { NlpSampleEntityValue, NlpSampleState } from '../schemas/types';
import { NlpSampleState } from '../schemas/types';
import { NlpEntityService } from '../services/nlp-entity.service';
import { NlpSampleEntityService } from '../services/nlp-sample-entity.service';
import { NlpSampleService } from '../services/nlp-sample.service';
import { NlpService } from '../services/nlp.service';

@UseInterceptors(CsrfInterceptor)
@Controller('nlpsample')
Expand All @@ -68,11 +64,9 @@ export class NlpSampleController extends BaseController<
> {
constructor(
private readonly nlpSampleService: NlpSampleService,
private readonly attachmentService: AttachmentService,
private readonly nlpSampleEntityService: NlpSampleEntityService,
private readonly nlpEntityService: NlpEntityService,
private readonly logger: LoggerService,
private readonly nlpService: NlpService,
private readonly languageService: LanguageService,
private readonly helperService: HelperService,
) {
Expand Down Expand Up @@ -369,129 +363,11 @@ export class NlpSampleController extends BaseController<
return deleteResult;
}

/**
* Imports NLP samples from a CSV file.
*
* @param file - The file path or ID of the CSV file to import.
*
* @returns A success message after the import process is completed.
*/
@CsrfCheck(true)
@Post('import/:file')
async import(
@Param('file')
file: string,
) {
// Check if file is present
const importedFile = await this.attachmentService.findOne(file);
if (!importedFile) {
throw new NotFoundException('Missing file!');
}
const filePath = importedFile
? join(config.parameters.uploadDir, importedFile.location)
: undefined;

// Check if file location is present
if (!fs.existsSync(filePath)) {
throw new NotFoundException('File does not exist');
}

const allEntities = await this.nlpEntityService.findAll();

// Check if file location is present
if (allEntities.length === 0) {
throw new NotFoundException(
'No entities found, please create them first.',
);
}

// Read file content
const data = fs.readFileSync(filePath, 'utf8');

// Parse local CSV file
const result: {
errors: any[];
data: Array<Record<string, string>>;
} = Papa.parse(data, {
header: true,
skipEmptyLines: true,
});

if (result.errors && result.errors.length > 0) {
this.logger.warn(
`Errors parsing the file: ${JSON.stringify(result.errors)}`,
);
throw new BadRequestException(result.errors, {
cause: result.errors,
description: 'Error while parsing CSV',
});
}
// Remove data with no intent
const filteredData = result.data.filter((d) => d.intent !== 'none');
const languages = await this.languageService.getLanguages();
const defaultLanguage = await this.languageService.getDefaultLanguage();
// Reduce function to ensure executing promises one by one
for (const d of filteredData) {
try {
// Check if a sample with the same text already exists
const existingSamples = await this.nlpSampleService.find({
text: d.text,
});

// Skip if sample already exists
if (Array.isArray(existingSamples) && existingSamples.length > 0) {
continue;
}

// Fallback to default language if 'language' is missing or invalid
if (!d.language || !(d.language in languages)) {
if (d.language) {
this.logger.warn(
`Language "${d.language}" does not exist, falling back to default.`,
);
}
d.language = defaultLanguage.code;
}

// Create a new sample dto
const sample: NlpSampleCreateDto = {
text: d.text,
trained: false,
language: languages[d.language].id,
};

// Create a new sample entity dto
const entities: NlpSampleEntityValue[] = allEntities
.filter(({ name }) => name in d)
.map(({ name }) => {
return {
entity: name,
value: d[name],
};
});

// Store any new entity/value
const storedEntities = await this.nlpEntityService.storeNewEntities(
sample.text,
entities,
['trait'],
);
// Store sample
const createdSample = await this.nlpSampleService.create(sample);
// Map and assign the sample ID to each stored entity
const sampleEntities = storedEntities.map((se) => ({
...se,
sample: createdSample?.id,
}));

// Store sample entities
await this.nlpSampleEntityService.createMany(sampleEntities);
} catch (err) {
this.logger.error('Error occurred when extracting data. ', err);
}
}

this.logger.log('Import process completed successfully.');
return { success: true };
/**
 * Imports NLP samples from an uploaded CSV dataset file.
 *
 * The file is received in memory via multer's FileInterceptor (field name
 * 'file') — note this reads `file.buffer`, which assumes memory storage;
 * TODO confirm multer is not configured with disk storage for this route.
 *
 * @param file - The uploaded CSV file (Express.Multer.File).
 *
 * @returns The result of parsing and persisting the dataset
 *          (presumably the created NLP samples; see
 *          NlpSampleService.parseAndSaveDataset — verify against service).
 */
@Post('import')
@UseInterceptors(FileInterceptor('file'))
async importFile(@UploadedFile() file: Express.Multer.File) {
// Decode the raw upload as UTF-8 text before handing it to the parser.
const datasetContent = file.buffer.toString('utf-8');
// Delegates validation, CSV parsing, and persistence to the service layer.
return await this.nlpSampleService.parseAndSaveDataset(datasetContent);
}
}
Loading

0 comments on commit bb83cd5

Please sign in to comment.