Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Target argument for scale, map etc #190

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions src/high-level/abstract-tensor.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ In the event TARGET is not specified, the result may return an array sharing mem
(lambda (&rest dims)
(apply #'(setf tref) (funcall function (apply #'tref source dims)) target dims))))))

(defgeneric map (function tensor)
(:documentation "Map elements of TENSOR, storing the output of FUNCTION on the element into the corresponding element of a new tensor")
(:method ((function function) (tensor abstract-tensor))
(map! function (deep-copy-tensor tensor))))
(defgeneric map (function tensor &optional target)
(:documentation "Map elements of TENSOR, storing the output of FUNCTION on the element into the corresponding element of TARGET if it is supplied or else a new tensor")
(:method ((function function) (tensor abstract-tensor &optional (target nil targetp)))
(map! function (if targetp target (deep-copy-tensor tensor)))))

(defgeneric into (function tensor)
(:documentation "Map indices of TENSOR, storing the output of FUNCTION on the index into the corresponding element of a new tensor
Expand All @@ -132,12 +132,17 @@ If LAYOUT is specified then traverse TENSOR in the specified order (column major
(lambda (tensor factor)
(map! (lambda (x) (* x factor)) tensor)))

(define-backend-function scale (tensor factor)
"Scale TENSOR by FACTOR, returning a new tensor of the same type as TENSOR")
(define-backend-function scale (tensor factor &optional target)
"Scale TENSOR by FACTOR.
If TARGET is specified then the result is stored in TARGET,
otherwise a new tensor of the same type as TENSOR is used for the result.")

(define-backend-implementation scale :lisp
(lambda (tensor factor)
(scale! (deep-copy-tensor tensor) factor)))
(lambda (tensor factor &optional (target nil targetp))
(scale! (if targetp
target
(deep-copy-tensor tensor))
factor)))

(defgeneric slice (tensor from to)
(:documentation "Slice a tensor from FROM to TO, returning a new tensor with the contained elements")
Expand Down