Skip to content
This repository has been archived by the owner on Dec 1, 2022. It is now read-only.

Commit

Permalink
Merge pull request #6 from ZondaX/fix/resources
Browse files Browse the repository at this point in the history
Improve resource handling
  • Loading branch information
jleni authored Feb 1, 2019
2 parents 918a8a7 + 7abf074 commit 69c15f1
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 46 deletions.
6 changes: 3 additions & 3 deletions apduWrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func WrapCommandAPDU(
}

// UnwrapResponseAPDU parses a response of 64 byte packets into the real data
func UnwrapResponseAPDU(channel uint16, pipe <- chan []byte, packetSize int) ([]byte, error) {
func UnwrapResponseAPDU(channel uint16, pipe <-chan []byte, packetSize int) ([]byte, error) {
var sequenceIdx uint16

var totalResult []byte
Expand All @@ -135,7 +135,7 @@ func UnwrapResponseAPDU(channel uint16, pipe <- chan []byte, packetSize int) ([]

for !done {
// Read next packet from the channel
buffer := <- pipe
buffer := <-pipe

result, responseSize, err := DeserializePacket(channel, buffer, sequenceIdx)
if err != nil {
Expand All @@ -157,4 +157,4 @@ func UnwrapResponseAPDU(channel uint16, pipe <- chan []byte, packetSize int) ([]
// Remove trailing zeros
totalResult = totalResult[:totalSize]
return totalResult, nil
}
}
40 changes: 20 additions & 20 deletions apduWrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

func Test_SerializePacket_EmptyCommand(t *testing.T) {
var command= make([]byte, 1)
var command = make([]byte, 1)

_, _, err := SerializePacket(0x0101, command, 64, 0)
assert.Nil(t, err, "Commands smaller than 3 bytes should return error")
Expand All @@ -42,9 +42,9 @@ func Test_SerializePacket_PacketSize(t *testing.T) {
commandLen uint16
}

h := header{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 32}
h := header{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 32}

var command= make([]byte, h.commandLen)
var command = make([]byte, h.commandLen)

result, _, _ := SerializePacket(
h.channel,
Expand All @@ -65,9 +65,9 @@ func Test_SerializePacket_Header(t *testing.T) {
commandLen uint16
}

h := header{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 32}
h := header{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 32}

var command= make([]byte, h.commandLen)
var command = make([]byte, h.commandLen)

result, _, _ := SerializePacket(
h.channel,
Expand All @@ -91,17 +91,17 @@ func Test_SerializePacket_Offset(t *testing.T) {
commandLen uint16
}

h := header{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 100}
h := header{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 100}

var command= make([]byte, h.commandLen)
var command = make([]byte, h.commandLen)

_, offset, _ := SerializePacket(
h.channel,
command,
packetSize,
h.sequenceIdx)

assert.Equal(t, packetSize - int(unsafe.Sizeof(h))+1, offset, "Wrong offset returned. Offset must point to the next comamnd byte that needs to be packet-ized.")
assert.Equal(t, packetSize-int(unsafe.Sizeof(h))+1, offset, "Wrong offset returned. Offset must point to the next comamnd byte that needs to be packet-ized.")
}

func Test_WrapCommandAPDU_NumberOfPackets(t *testing.T) {
Expand All @@ -119,9 +119,9 @@ func Test_WrapCommandAPDU_NumberOfPackets(t *testing.T) {
tag uint8
}

h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 100}
h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 100}

var command= make([]byte, h1.commandLen)
var command = make([]byte, h1.commandLen)

result, _ := WrapCommandAPDU(
h1.channel,
Expand All @@ -146,9 +146,9 @@ func Test_WrapCommandAPDU_CheckHeaders(t *testing.T) {
tag uint8
}

h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 100}
h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 100}

var command= make([]byte, h1.commandLen)
var command = make([]byte, h1.commandLen)

result, _ := WrapCommandAPDU(
h1.channel,
Expand Down Expand Up @@ -181,9 +181,9 @@ func Test_WrapCommandAPDU_CheckData(t *testing.T) {
tag uint8
}

h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx:0, commandLen: 200}
h1 := firstHeader{channel: 0x0101, tag: 0x05, sequenceIdx: 0, commandLen: 200}

var command= make([]byte, h1.commandLen)
var command = make([]byte, h1.commandLen)

for i := range command {
command[i] = byte(i % 256)
Expand Down Expand Up @@ -228,9 +228,9 @@ func Test_DeserializePacket_FirstPacket(t *testing.T) {

output, totalSize, err := DeserializePacket(0x0101, packet, 0)

assert.Nil(t,err, "Simple deserialize should not have errors")
assert.Nil(t, err, "Simple deserialize should not have errors")
assert.Equal(t, len(sampleCommand), int(totalSize), "TotalSize is incorrect")
assert.Equal(t, packetSize - firstPacketHeaderSize, len(output), "Size of the deserialized packet is wrong")
assert.Equal(t, packetSize-firstPacketHeaderSize, len(output), "Size of the deserialized packet is wrong")
assert.True(t, bytes.Compare(output[:len(sampleCommand)], sampleCommand) == 0, "Deserialized message does not match the original")
}

Expand All @@ -243,9 +243,9 @@ func Test_DeserializePacket_SecondMessage(t *testing.T) {

output, totalSize, err := DeserializePacket(0x0101, packet, 1)

assert.Nil(t,err, "Simple deserialize should not have errors")
assert.Nil(t, err, "Simple deserialize should not have errors")
assert.Equal(t, 0, int(totalSize), "TotalSize should not be returned from deserialization of non-first packet")
assert.Equal(t, packetSize - firstPacketHeaderSize, len(output), "Size of the deserialized packet is wrong")
assert.Equal(t, packetSize-firstPacketHeaderSize, len(output), "Size of the deserialized packet is wrong")
assert.True(t, bytes.Compare(output[:len(sampleCommand)], sampleCommand) == 0, "Deserialized message does not match the original")
}

Expand All @@ -256,15 +256,15 @@ func Test_UnwrapApdu_SmokeTest(t *testing.T) {
var packetSize int = 64

// Initialize some dummy input
var input= make([]byte, inputSize)
var input = make([]byte, inputSize)
for i := range input {
input[i] = byte(i % 256)
}

serialized, _ := WrapCommandAPDU(channel, input, packetSize)

// Allocate enough buffers to keep all the packets
pipe := make(chan []byte, int(math.Ceil(float64(inputSize) / float64(packetSize))))
pipe := make(chan []byte, int(math.Ceil(float64(inputSize)/float64(packetSize))))
// Send all the packets to the pipe
for len(serialized) > 0 {
pipe <- serialized[:packetSize]
Expand Down
29 changes: 14 additions & 15 deletions ledger.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ package ledger_go
import (
"errors"
"fmt"
"github.com/zondax/hid"
"sync"

"github.com/zondax/hid"
)

const (
Expand All @@ -34,7 +35,7 @@ const (
type Ledger struct {
device hid.Device
readCo sync.Once
readChannel chan [] byte
readChannel chan []byte
Logging bool
}

Expand Down Expand Up @@ -70,23 +71,17 @@ func FindLedger() (*Ledger, error) {
devices := hid.Enumerate(VendorLedger, 0)

for _, d := range devices {
if d.VendorID == VendorLedger && d.UsagePage == UsagePageLedger {
device, err := d.Open()
if err != nil {
return nil, err
}
return NewLedger(device), nil
}
deviceFound := d.UsagePage == UsagePageLedger
deviceFound = deviceFound || (d.Product == "Nano S" && d.Interface == 0)

// Linux discovery
if d.VendorID == VendorLedger && d.Product == "Nano S" && d.Interface == 0 {
if deviceFound {
device, err := d.Open()
if err != nil {
return nil, err
if err == nil {
return NewLedger(device), nil
}
return NewLedger(device), nil
}
}

return nil, errors.New("no ledger connected")
}

Expand Down Expand Up @@ -126,6 +121,10 @@ func ErrorMessage(errorCode uint16) string {
}
}

func (ledger *Ledger) Close() error {
return ledger.device.Close()
}

func (ledger *Ledger) Write(buffer []byte) (int, error) {
totalBytes := len(buffer)
totalWrittenBytes := 0
Expand All @@ -150,7 +149,7 @@ func (ledger *Ledger) Read() <-chan []byte {
return ledger.readChannel
}

func (ledger *Ledger) initReadChannel(){
func (ledger *Ledger) initReadChannel() {
ledger.readChannel = make(chan []byte, 30)
go ledger.readThread()
}
Expand Down
16 changes: 8 additions & 8 deletions ledger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ package ledger_go
import (
"encoding/hex"
"fmt"
"github.com/zondax/hid"
"github.com/stretchr/testify/assert"
"github.com/zondax/hid"
"testing"
)

Expand All @@ -41,7 +41,7 @@ func Test_FindLedger(t *testing.T) {
fmt.Println("\n*********************************")
fmt.Println("Did you enter the password??")
fmt.Println("*********************************")
t.Fatalf( "Error: %s", err.Error())
t.Fatalf("Error: %s", err.Error())
}
assert.NotNil(t, ledger)
}
Expand All @@ -52,7 +52,7 @@ func Test_BasicExchange(t *testing.T) {
fmt.Println("\n*********************************")
fmt.Println("Did you enter the password??")
fmt.Println("*********************************")
t.Fatalf( "Error: %s", err.Error())
t.Fatalf("Error: %s", err.Error())
}
assert.NotNil(t, ledger)

Expand All @@ -63,7 +63,7 @@ func Test_BasicExchange(t *testing.T) {

if err != nil {
fmt.Printf("iteration %d\n", i)
t.Fatalf( "Error: %s", err.Error())
t.Fatalf("Error: %s", err.Error())
}

assert.Equal(t, 4, len(response))
Expand All @@ -76,23 +76,23 @@ func Test_LongExchange(t *testing.T) {
fmt.Println("\n*********************************")
fmt.Println("Did you enter the password??")
fmt.Println("*********************************")
t.Fatalf( "Error: %s", err.Error())
t.Fatalf("Error: %s", err.Error())
}
assert.NotNil(t, ledger)

path := "052c000080760000800000008000000000000000000000000000000000000000000000000000000000";
path := "052c000080760000800000008000000000000000000000000000000000000000000000000000000000"
pathBytes, err := hex.DecodeString(path)
if err != nil {
t.Fatalf("invalid path in test")
}

header := []byte { 0x55, 1, 0, 0, byte(len(pathBytes))}
header := []byte{0x55, 1, 0, 0, byte(len(pathBytes))}
message := append(header, pathBytes...)

response, err := ledger.Exchange(message)

if err != nil {
t.Fatalf( "Error: %s", err.Error())
t.Fatalf("Error: %s", err.Error())
}

assert.Equal(t, 65, len(response))
Expand Down

0 comments on commit 69c15f1

Please sign in to comment.