diff --git a/source/backends/aten/arithmetic.lisp b/source/backends/aten/arithmetic.lisp index fe5a425c..c6d1d07d 100644 --- a/source/backends/aten/arithmetic.lisp +++ b/source/backends/aten/arithmetic.lisp @@ -14,7 +14,9 @@ (impl AddNode +) (impl SubNode -) (impl MulNode *) - (impl DivNode /)) + (impl DivNode /) + ) + (define-impl (MoveTensorNode :device Aten :extends (AtenOp)) :forward @@ -24,3 +26,5 @@ (setf (aten-bp self) (unary->aten-code (list x y) 1 #'(lambda (x y) `((setf ,x ,y)))) (aten-inputs self) (list x y) (aten-outputs self) (list x))))) + + diff --git a/source/backends/aten/codegen.lisp b/source/backends/aten/codegen.lisp index 5b57f9d5..dde916a4 100644 --- a/source/backends/aten/codegen.lisp +++ b/source/backends/aten/codegen.lisp @@ -56,6 +56,7 @@ (intern (symbol-name (tensor-id tensor))) `(aref ,(intern (symbol-name (tensor-id tensor))) (+ + ,(intern (tensor-initial-offset-name tensor)) ,@(loop for dim upfrom 0 below (dims tensor) for view in (tensor-view tensor) for index in indices diff --git a/source/backends/aten/tensor.lisp b/source/backends/aten/tensor.lisp index 122a1c81..c540de6c 100644 --- a/source/backends/aten/tensor.lisp +++ b/source/backends/aten/tensor.lisp @@ -16,6 +16,9 @@ Base class for various aten backends. ")) +(defun tensor-initial-offset-name (tensor) + (format nil "_~a_offset" (tensor-id tensor))) + (defun rest->alist (rest) (loop for i upfrom 0 below (length rest) by 2 if (not (find (nth i rest) `(:debug))) @@ -55,6 +58,9 @@ Base class for various aten backends. (let ((id2table (make-hash-table :test #'equal))) (loop for name in (bp-deps (aten-bp op)) do (setf (gethash (idkey name) id2table) name)) + (loop for tensor in (aten-inputs op) + for name = (idkey (tensor-initial-offset-name tensor)) + do (setf (gethash name id2table) #'(lambda () (tensor-initial-offset tensor)))) (setf (aten-scalars op) id2table)) (aten/ir::make-composite :inputs (remove-duplicates @@ -63,6 +69,9 @@ Base class for various aten backends. 'list (alexandria:compose #'aten/ir::%parse-aten #'tensor->shape-tracker) (aten-inputs op)) + (loop for tensor in (aten-inputs op) + collect + (aten/ir::%parse-aten (format nil "~a{Int}[]<>()" (tensor-initial-offset-name tensor)))) (loop for name in scalars collect (aten/ir::%parse-aten (coerce (format nil "~a{Int}[]<>()" name) '(simple-array character (*)))))) :test #'equal) @@ -110,15 +119,18 @@ Base class for various aten backends. collect (position out (aten-inputs (wf/vm:wfop-node wfir)) :test #'equal)))) (setf (wf/vm:wfop-op wfir) #'(lambda (&rest args) - ;;(print "++++++") - ;;(print (map 'list #'cl-waffe2/vm::maybe-observe-axis inputs)) - ;;(print (map 'list #'tensor-vec args)) - ;;(print (aten/ir:composite-code (aten/engine::cc-base-composite (aten-composite (wf/vm:wfop-node wfir))))) (apply (aten/engine::cc-caller (aten-composite (wf/vm:wfop-node wfir))) (append (map 'list #'tensor-vec args) - (map 'list #'cl-waffe2/vm::maybe-observe-axis inputs))) + (loop with c fixnum = 0 + for val in inputs + if (functionp val) + collect (prog1 + (tensor-initial-offset (nth c args)) + (incf c)) + else + collect (cl-waffe2/vm::maybe-observe-axis val)))) (apply #'values (loop for o in out-positions diff --git a/source/backends/aten/unary.lisp b/source/backends/aten/unary.lisp index e7c0f0a9..6e8f37db 100644 --- a/source/backends/aten/unary.lisp +++ b/source/backends/aten/unary.lisp @@ -32,7 +32,8 @@ (impl log2Node log2) (impl log10Node log10) (impl logeNode log) - (impl Log1pNode log1p)) + (impl Log1pNode log1p) + ) (define-impl (ExptNode :device Aten :extends (AtenOp)) :forward ((self x out n) diff --git a/source/vm/vm.lisp b/source/vm/vm.lisp index 6479becb..8d561032 100644 --- a/source/vm/vm.lisp +++ b/source/vm/vm.lisp @@ -59,7 +59,7 @@ This parameter is useful for printing how all instructions are performed. If set (s2a (actual-shape var))) (setf (cl-waffe2/vm.generic-tensor::tensor-visible-shape place) s1a (cl-waffe2/vm.generic-tensor::tensor-visible-shape var) s2a) - (unwind-protect (%vm-move place var) + (unwind-protect (%vm-move place var) (setf (cl-waffe2/vm.generic-tensor::tensor-visible-shape place) s1 (cl-waffe2/vm.generic-tensor::tensor-visible-shape var) s2))))))) nil) @@ -79,6 +79,7 @@ This parameter is useful for printing how all instructions are performed. If set (or (when state (cl-waffe2/vm.generic-tensor::statecontainer-forward-result state)) tensor))) + ;;(setf (tensor-initial-offset res) (tensor-initial-offset tensor)) (the AbstractTensor res))))) (declaim (ftype (function (list list) t) write-result))