Skip to content

Commit

Permalink
Output logits
Browse files Browse the repository at this point in the history
  • Loading branch information
platypii committed Jul 24, 2024
1 parent 9e7fec8 commit 98d4e6d
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 12 deletions.
36 changes: 24 additions & 12 deletions src/hylang.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,37 @@ function textToFeatures(text) {
}

/**
* Classify the input text based on features
* Calculate logits for language detection
* @param {string} text - input text
* @returns {string} predicted programming language
* @returns {number[]} logits for each language
*/
export function detectLanguage(text) {
export function detectLanguageLogits(text) {
const features = textToFeatures(text)

// Calculate the logits for each class
const logits = weights.map((weightsRow, i) => {
const weightedSum = weightsRow.reduce((sum, weight, j) => sum + weight * features[j], biases[i])
return weightedSum
return weights.map((weightsRow, i) => {
return weightsRow.reduce((sum, weight, j) => sum + weight * features[j], biases[i])
})
}

// Apply softmax to logits to get probabilities
const expLogits = logits.map(logit => Math.exp(logit))
const sumExpLogits = expLogits.reduce((sum, expLogit) => sum + expLogit, 0)
const probabilities = expLogits.map(expLogit => expLogit / sumExpLogits)
/**
* Apply softmax to an array of numbers
* @param {number[]} arr - input array
* @returns {number[]} softmax probabilities
*/
function softmax(arr) {
const expArr = arr.map(Math.exp)
const sumExpArr = expArr.reduce((sum, exp) => sum + exp, 0)
return expArr.map(exp => exp / sumExpArr)
}

// Find the class with the highest probability
/**
* Classify the input text based on features
* @param {string} text - input text
* @returns {string} predicted programming language
*/
export function detectLanguage(text) {
const logits = detectLanguageLogits(text)
const probabilities = softmax(logits)
const predictedClass = probabilities.indexOf(Math.max(...probabilities))
return languages[predictedClass]
}
44 changes: 44 additions & 0 deletions test/logits.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { describe, expect, it } from 'vitest'
import { detectLanguageLogits } from '../src/hylang.js'

describe('detectLanguageLogits', () => {
it('should return an array of logits for JavaScript', () => {
const jsCode = `
function greet(name) {
console.log('Hello, ' + name + '!')
}
const numbers = [1, 2, 3, 4, 5]
let sum = 0
for (let i = 0; i < numbers.length; i++) {
sum += numbers[i]
}
console.log('Sum:', sum)
`

const logits = detectLanguageLogits(jsCode)
expect(logits).toBeInstanceOf(Array)
expect(logits.length).toBeGreaterThan(0)
expect(logits.every(logit => typeof logit === 'number')).toBe(true)
expect(Math.max(...logits)).toBeGreaterThan(0)
})

it('should return different logits for different inputs', () => {
const jsCode = `
function test() {
return 'Hello, World!';
}
`

const pythonCode = `
def test():
return 'Hello, World!'
`

const jsLogits = detectLanguageLogits(jsCode)
const pythonLogits = detectLanguageLogits(pythonCode)
expect(jsLogits).not.toEqual(pythonLogits)
})
})

0 comments on commit 98d4e6d

Please sign in to comment.