Skip to content

Commit

Permalink
Merge pull request #129 from JunNishimura/#128
Browse files Browse the repository at this point in the history
add reset command
  • Loading branch information
JunNishimura committed Jun 18, 2023
2 parents 143afe2 + 4a24f06 commit bab1bc6
Show file tree
Hide file tree
Showing 11 changed files with 510 additions and 53 deletions.
37 changes: 1 addition & 36 deletions cmd/commit.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,44 +98,9 @@ func commit() error {
return nil
}

func getEntriesFromTree(rootName string, nodes []*object.Node) ([]*store.Entry, error) {
var entries []*store.Entry

for _, node := range nodes {
if len(node.Children) == 0 {
var entryName string
if rootName == "" {
entryName = node.Name
} else {
entryName = fmt.Sprintf("%s/%s", rootName, node.Name)
}
newEntry := &store.Entry{
Hash: node.Hash,
NameLength: uint16(len(entryName)),
Path: []byte(entryName),
}
entries = append(entries, newEntry)
} else {
var newRootName string
if rootName == "" {
newRootName = node.Name
} else {
newRootName = fmt.Sprintf("%s/%s", rootName, node.Name)
}
childEntries, err := getEntriesFromTree(newRootName, node.Children)
if err != nil {
return nil, err
}
entries = append(entries, childEntries...)
}
}

return entries, nil
}

func isIndexDifferentFromTree(index *store.Index, tree *object.Tree) (bool, error) {
rootName := ""
gotEntries, err := getEntriesFromTree(rootName, tree.Children)
gotEntries, err := store.GetEntriesFromTree(rootName, tree.Children)
if err != nil {
return false, err
}
Expand Down
142 changes: 142 additions & 0 deletions cmd/reset.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*
Copyright © 2023 NAME HERE <EMAIL ADDRESS>
*/
package cmd

import (
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"time"

"github.com/JunNishimura/Goit/internal/log"
"github.com/JunNishimura/Goit/internal/object"
"github.com/JunNishimura/Goit/internal/store"
"github.com/spf13/cobra"
)

var (
isSoft bool
isMixed bool
isHard bool
resetRegexp = regexp.MustCompile(`HEAD@\{\d\}`)
)

func resetHead(arg, rootGoitPath string, logRecord *store.LogRecord, head *store.Head, refs *store.Refs, conf *store.Config) error {
// reset Head
prevHeadHash := head.Commit.Hash
if err := head.Reset(rootGoitPath, refs, logRecord.Hash); err != nil {
return fmt.Errorf("fail to reset HEAD: %w", err)
}

// log
newRecord := log.NewRecord(log.ResetRecord, prevHeadHash, logRecord.Hash, conf.GetUserName(), conf.GetEmail(), time.Now(), fmt.Sprintf("moving to %s", arg))
if err := gLogger.WriteHEAD(newRecord); err != nil {
return fmt.Errorf("log error: %w", err)
}
if err := gLogger.WriteBranch(newRecord, head.Reference); err != nil {
return fmt.Errorf("log error: %w", err)
}

return nil
}

func resetIndex(rootGoitPath string, logRecord *store.LogRecord, index *store.Index) error {
// reset index
if err := index.Reset(rootGoitPath, logRecord.Hash); err != nil {
return fmt.Errorf("fail to reset index: %w", err)
}

return nil
}

func resetWorkingTree(rootGoitPath string, index *store.Index) error {
for _, entry := range index.Entries {
obj, err := object.GetObject(rootGoitPath, entry.Hash)
if err != nil {
return fmt.Errorf("fail to get object: %w", err)
}
if err := obj.ReflectToWorkingTree(rootGoitPath, string(entry.Path)); err != nil {
return fmt.Errorf("fail to reflect %s to working directory: %w", string(entry.Path), err)
}
}

return nil
}

// resetCmd represents the reset command
var resetCmd = &cobra.Command{
Use: "reset",
Short: "reset current HEAD to the specified state",
Long: "reset current HEAD to the specified state",
PreRunE: func(cmd *cobra.Command, args []string) error {
if client.RootGoitPath == "" {
return ErrGoitNotInitialized
}
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
// flag validation
if isSoft || isHard {
isMixed = false
}
if !((isSoft && !isMixed && !isHard) ||
(!isSoft && isMixed && !isHard) ||
(!isSoft && !isMixed && isHard)) {
return errors.New("invalid flags")
}

// args validation
if !(len(args) == 1 && resetRegexp.MatchString(args[0])) {
return errors.New("only one argument is acceptible. argument format is 'HEAD@{number}'")
}

// get log record
reflog, err := store.NewReflog(client.RootGoitPath, client.Head, client.Refs)
if err != nil {
return fmt.Errorf("fail to initialize reflog: %w", err)
}
sp := strings.Split(args[0], "HEAD@")[1]
headNum, err := strconv.Atoi(sp[1 : len(sp)-1])
if err != nil {
return fmt.Errorf("fail to convert number '%s': %w", args[0], err)
}
logRecord, err := reflog.GetRecord(headNum)
if err != nil {
return fmt.Errorf("fail to get log record: %w", err)
}

// reset HEAD
if isSoft || isMixed || isHard {
if err := resetHead(args[0], client.RootGoitPath, logRecord, client.Head, client.Refs, client.Conf); err != nil {
return fmt.Errorf("fail to reset HEAD: %w", err)
}
}

// reset index
if isMixed || isHard {
if err := resetIndex(client.RootGoitPath, logRecord, client.Idx); err != nil {
return fmt.Errorf("fail to reset index: %w", err)
}
}

// reset working tree
if isHard {
if err := resetWorkingTree(client.RootGoitPath, client.Idx); err != nil {
return fmt.Errorf("fail to reset working tree: %w", err)
}
}

return nil
},
}

func init() {
rootCmd.AddCommand(resetCmd)

resetCmd.Flags().BoolVar(&isSoft, "soft", false, "reset HEAD")
resetCmd.Flags().BoolVar(&isMixed, "mixed", true, "reset HEAD and index")
resetCmd.Flags().BoolVar(&isHard, "hard", false, "reset HEAD, index and working tree")
}
5 changes: 5 additions & 0 deletions internal/log/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const (
CommitRecord
CheckoutRecord
BranchRecord
ResetRecord
)

func NewRecordType(typeString string) RecordType {
Expand All @@ -27,6 +28,8 @@ func NewRecordType(typeString string) RecordType {
return CheckoutRecord
case "branch":
return BranchRecord
case "reset":
return ResetRecord
default:
return UndefinedRecord
}
Expand All @@ -40,6 +43,8 @@ func (t RecordType) String() string {
return "checkout"
case BranchRecord:
return "branch"
case ResetRecord:
return "reset"
default:
return "undefined"
}
Expand Down
60 changes: 60 additions & 0 deletions internal/log/logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,37 @@ func TestNewRecord(t *testing.T) {
},
}
}(),
func() *test {
hash, _ := hex.DecodeString("87f3c49bccf2597484ece08746d3ee5defaba335")
now := time.Now()
unixtime := fmt.Sprint(now.Unix())
_, offset := now.Zone()
offsetMinutes := offset / 60
timeDiff := fmt.Sprintf("%+03d%02d", offsetMinutes/60, offsetMinutes%60)

return &test{
name: "success: reset record",
args: args{
recType: ResetRecord,
from: sha.SHA1(hash),
to: sha.SHA1(hash),
name: "Test Taro",
email: "test@example.com",
t: now,
message: "test",
},
want: &record{
recType: ResetRecord,
from: sha.SHA1(hash),
to: sha.SHA1(hash),
name: "Test Taro",
email: "test@example.com",
unixtime: unixtime,
timeDiff: timeDiff,
message: "test",
},
}
}(),
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -185,6 +216,20 @@ func TestWriteHEAD(t *testing.T) {
wantErr: false,
}
}(),
func() *test {
hash, _ := hex.DecodeString("87f3c49bccf2597484ece08746d3ee5defaba335")
now := time.Now()
rec := NewRecord(ResetRecord, hash, hash, "Test Taro", "test@example.com", now, "test")

return &test{
name: "success: reset record",
args: args{
rec: rec,
},
want: fmt.Sprintf("%s %s %s <%s> %s %s\t%s: %s\n", rec.from, rec.to, rec.name, rec.email, rec.unixtime, rec.timeDiff, rec.recType, rec.message),
wantErr: false,
}
}(),
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -265,6 +310,21 @@ func TestWriteBranch(t *testing.T) {
wantErr: false,
}
}(),
func() *test {
hash, _ := hex.DecodeString("87f3c49bccf2597484ece08746d3ee5defaba335")
now := time.Now()
rec := NewRecord(ResetRecord, hash, hash, "Test Taro", "test@example.com", now, "test")

return &test{
name: "success: reset record",
args: args{
rec: rec,
branchName: "test",
},
want: fmt.Sprintf("%s %s %s <%s> %s %s\t%s: %s\n", rec.from, rec.to, rec.name, rec.email, rec.unixtime, rec.timeDiff, rec.recType, rec.message),
wantErr: false,
}
}(),
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
16 changes: 16 additions & 0 deletions internal/object/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,19 @@ func (o *Object) Write(rootGoitPath string) error {
}
return nil
}

func (o *Object) ReflectToWorkingTree(rootGoitPath, path string) error {
rootDir := filepath.Dir(rootGoitPath)
filePath := filepath.Join(rootDir, path)
f, err := os.Create(filePath)
if err != nil {
return fmt.Errorf("fail to create file %s: %w", filePath, err)
}
defer f.Close()

if _, err := f.Write(o.Data); err != nil {
return fmt.Errorf("fail to write object to %s: %w", filePath, err)
}

return nil
}
67 changes: 67 additions & 0 deletions internal/object/object_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,70 @@ func TestWrite(t *testing.T) {
})
}
}

func TestReflectToWorkingTree(t *testing.T) {
type args struct {
path string
}
type fields struct {
data string
}
type test struct {
name string
args args
fields fields
want string
wantErr bool
}
tests := []*test{
func() *test {
return &test{
name: "success",
args: args{
path: "test.txt",
},
fields: fields{
data: "hello, world",
},
want: "hello, world",
wantErr: false,
}
}(),
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
// .goit initialization
goitDir := filepath.Join(tmpDir, ".goit")
if err := os.Mkdir(goitDir, os.ModePerm); err != nil {
t.Logf("%v: %s", err, goitDir)
}
// make .goit/objects directory
objectsDir := filepath.Join(goitDir, "objects")
if err := os.Mkdir(objectsDir, os.ModePerm); err != nil {
t.Logf("%v: %s", err, objectsDir)
}

obj, err := NewObject(BlobObject, []byte(tt.fields.data))
if err != nil {
t.Log(err)
}
if err := obj.Write(goitDir); err != nil {
t.Log(err)
}

if err := obj.ReflectToWorkingTree(goitDir, tt.args.path); (err != nil) != tt.wantErr {
t.Errorf("got = %v, want = %v", err, tt.wantErr)
}

filePath := filepath.Join(tmpDir, tt.args.path)
got, err := os.ReadFile(filePath)
if err != nil {
t.Log(err)
}
if string(got) != tt.want {
t.Errorf("got = %s, want = %s", string(got), tt.want)
}
})
}
}
Loading

0 comments on commit bab1bc6

Please sign in to comment.