Skip to content
This repository has been archived by the owner on Dec 11, 2024. It is now read-only.

Commit

Permalink
[BugFix] adding initial_offset when using Aten runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
hikettei committed Jun 3, 2024
1 parent 22acca5 commit 212d30e
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 8 deletions.
6 changes: 5 additions & 1 deletion source/backends/aten/arithmetic.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
(impl AddNode +)
(impl SubNode -)
(impl MulNode *)
(impl DivNode /))
(impl DivNode /)
)


(define-impl (MoveTensorNode :device Aten :extends (AtenOp))
:forward
Expand All @@ -24,3 +26,5 @@
(setf (aten-bp self) (unary->aten-code (list x y) 1 #'(lambda (x y) `((setf ,x ,y))))
(aten-inputs self) (list x y)
(aten-outputs self) (list x)))))


1 change: 1 addition & 0 deletions source/backends/aten/codegen.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
(intern (symbol-name (tensor-id tensor)))
`(aref ,(intern (symbol-name (tensor-id tensor)))
(+
,(intern (tensor-initial-offset-name tensor))
,@(loop for dim upfrom 0 below (dims tensor)
for view in (tensor-view tensor)
for index in indices
Expand Down
22 changes: 17 additions & 5 deletions source/backends/aten/tensor.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
Base class for various aten backends.
"))

(defun tensor-initial-offset-name (tensor)
(format nil "_~a_offset" (tensor-id tensor)))

(defun rest->alist (rest)
(loop for i upfrom 0 below (length rest) by 2
if (not (find (nth i rest) `(:debug)))
Expand Down Expand Up @@ -55,6 +58,9 @@ Base class for various aten backends.
(let ((id2table (make-hash-table :test #'equal)))
(loop for name in (bp-deps (aten-bp op))
do (setf (gethash (idkey name) id2table) name))
(loop for tensor in (aten-inputs op)
for name = (idkey (tensor-initial-offset-name tensor))
do (setf (gethash name id2table) #'(lambda () (tensor-initial-offset tensor))))
(setf (aten-scalars op) id2table))
(aten/ir::make-composite
:inputs (remove-duplicates
Expand All @@ -63,6 +69,9 @@ Base class for various aten backends.
'list
(alexandria:compose #'aten/ir::%parse-aten #'tensor->shape-tracker)
(aten-inputs op))
(loop for tensor in (aten-inputs op)
collect
(aten/ir::%parse-aten (format nil "~a{Int}[]<>()" (tensor-initial-offset-name tensor))))
(loop for name in scalars
collect (aten/ir::%parse-aten (coerce (format nil "~a{Int}[]<>()" name) '(simple-array character (*))))))
:test #'equal)
Expand Down Expand Up @@ -110,15 +119,18 @@ Base class for various aten backends.
collect (position out (aten-inputs (wf/vm:wfop-node wfir)) :test #'equal))))
(setf (wf/vm:wfop-op wfir)
#'(lambda (&rest args)
;;(print "++++++")
;;(print (map 'list #'cl-waffe2/vm::maybe-observe-axis inputs))
;;(print (map 'list #'tensor-vec args))
;;(print (aten/ir:composite-code (aten/engine::cc-base-composite (aten-composite (wf/vm:wfop-node wfir)))))
(apply
(aten/engine::cc-caller (aten-composite (wf/vm:wfop-node wfir)))
(append
(map 'list #'tensor-vec args)
(map 'list #'cl-waffe2/vm::maybe-observe-axis inputs)))
(loop with c fixnum = 0
for val in inputs
if (functionp val)
collect (prog1
(tensor-initial-offset (nth c args))
(incf c))
else
collect (cl-waffe2/vm::maybe-observe-axis val))))
(apply
#'values
(loop for o in out-positions
Expand Down
3 changes: 2 additions & 1 deletion source/backends/aten/unary.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
(impl log2Node log2)
(impl log10Node log10)
(impl logeNode log)
(impl Log1pNode log1p))
(impl Log1pNode log1p)
)

(define-impl (ExptNode :device Aten :extends (AtenOp))
:forward ((self x out n)
Expand Down
3 changes: 2 additions & 1 deletion source/vm/vm.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ This parameter is useful for printing how all instructions are performed. If set
(s2a (actual-shape var)))
(setf (cl-waffe2/vm.generic-tensor::tensor-visible-shape place) s1a
(cl-waffe2/vm.generic-tensor::tensor-visible-shape var) s2a)
(unwind-protect (%vm-move place var)
(unwind-protect (%vm-move place var)
(setf (cl-waffe2/vm.generic-tensor::tensor-visible-shape place) s1
(cl-waffe2/vm.generic-tensor::tensor-visible-shape var) s2)))))))
nil)
Expand All @@ -79,6 +79,7 @@ This parameter is useful for printing how all instructions are performed. If set
(or (when state
(cl-waffe2/vm.generic-tensor::statecontainer-forward-result state))
tensor)))
;;(setf (tensor-initial-offset res) (tensor-initial-offset tensor))
(the AbstractTensor res)))))

(declaim (ftype (function (list list) t) write-result))
Expand Down

0 comments on commit 212d30e

Please sign in to comment.