Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support CUDA stream with stream memory pool #306

Merged
merged 32 commits into from
Nov 6, 2017
Merged

Conversation

sonots
Copy link
Contributor

@sonots sonots commented Jul 18, 2017

Fix #225

Support CUDA stream. stream can be specified with with statement or use() method as:

import cupy as cp

x_gpu = cp.array([1, 2, 3])

with cp.cuda.stream.Stream():
   y_gpu = cp.linalg.norm(x_gpu)

stream = cp.cuda.stream.Stream()
stream.use()
y_gpu = cp.linalg.norm(x_gpu)

nvprof --print-gpu-trace python ~/test_cupy.py shows cupy kernels are executed in another stream.

To support CUDA stream, this PR changes memory pool, too.
A memory pool is created for each stream separately so that parallel computations among streams do not touch memory used in another stream.

Note that cudaMalloc is issued on default stream always.

Incompatibility changes:

  • cupy.cuda.generator.RandomState: setStream() method was removed
  • cupy.cuda.stream.Stream(null=True) is prohibited to assure cupy.cuda.stream.Stream.null object is always used to specify the default stream.

@sonots sonots force-pushed the stream_pool branch 6 times, most recently from 9a73412 to ef4dc29 Compare July 19, 2017 08:24
@sonots
Copy link
Contributor Author

sonots commented Jul 19, 2017

Added implementation to support stream on all CUDA libraries, cudnn, cublas, curand, cusparse, cusolver, and thrust.

@sonots
Copy link
Contributor Author

sonots commented Jul 19, 2017

Because it is difficult to test whether CUDA is really using stream, I added examples and verified by running them with nvprof --print-gpu-trace.

@sonots
Copy link
Contributor Author

sonots commented Jul 19, 2017

Wrote about incompatibility changes #306 (comment)

@sonots sonots force-pushed the stream_pool branch 6 times, most recently from da8880c to 7791731 Compare July 19, 2017 15:53
@sonots sonots requested review from rezoo and removed request for rezoo July 19, 2017 15:59
@sonots sonots changed the title [WIP] Support CUDA stream with stream memory pool [WIP] Support CUDA stream for kernel executions with stream memory pool Jul 20, 2017
@sonots sonots changed the title [WIP] Support CUDA stream for kernel executions with stream memory pool [WIP] Support CUDA stream with stream memory pool Jul 20, 2017
@sonots sonots force-pushed the stream_pool branch 9 times, most recently from 4d62dad to 5ed26d7 Compare July 20, 2017 22:51
@sonots
Copy link
Contributor Author

sonots commented Oct 22, 2017

I am now thinking to add an interface to free_all_blocks of an arena for a given stream. <= done

@sonots sonots force-pushed the stream_pool branch 4 times, most recently from 76a85b5 to 8902608 Compare October 22, 2017 19:53
@sonots
Copy link
Contributor Author

sonots commented Oct 23, 2017

Jenkins fails cyfunction is not a Python function with cd docs && make html as:

/work/cupy/docs/source/reference/generated/cupy.cuda.Event.rst:53: WARNING: error while formatting arguments for cupy.cuda.Event.record: <cyfunction Event.record at 0x2aea1bcb8d10> is not a Python function

I am struggling to resolve this, but I have no clue how to resolve this issue.
I found pandas has same issue pandas-dev/pandas#5218, but it seems they have also no idea.

@sonots
Copy link
Contributor Author

sonots commented Oct 23, 2017

Hmm, it seems make html succeeds with python 3.6.2, but fails with python 2.7.13.

@sonots
Copy link
Contributor Author

sonots commented Oct 23, 2017

It seems insepect.isfunction(cyfuction) returns True from python 3.4 https://groups.google.com/forum/#!topic/cython-users/FvcMPk9n2X8. Can we use python >= 3.4 to generate our sphinx docs in chainer-test?

@sonots
Copy link
Contributor Author

sonots commented Oct 25, 2017

#306 (comment)

I've debugged sphinx. sphinx has config.suprress_warnings http://www.sphinx-doc.org/ja/stable/config.html, but cyfunction is not a Python function was not supported.

So, another way (one way is to use python >=3.4 ) to avoid is not to treat warnings as error with a change:

Makefile

-  6 SPHINXBUILD   = sphinx-build -W
+   6 SPHINXBUILD   = sphinx-build

This is another problem, but I noticed that number of doctests were increased as:

sphinx-build -W

Doctest summary
===============
   60 tests
    0 failures in tests
    0 failures in setup code
    0 failures in cleanup code

sphinx-build

Doctest summary
===============
   95 tests
    0 failures in tests
    0 failures in setup code
    0 failures in cleanup code

@sonots
Copy link
Contributor Author

sonots commented Oct 26, 2017

Memory.__init__ also get raised TypeError, but it seems it is not included in the warnings. It seems that it was reason why doctest did not fail before... 🤔

EDIT: hmm, looks __init__ are ignored.

('util/inspect.py:115', <cyfunction Memory.__init__ at 0x7f744a3cab90>)
('util/inspect.py:115', <cyfunction Stream.__init__ at 0x7f744afb6ad0>)
('util/inspect.py:115', <cyfunction Event.__init__ at 0x7f744afb6710>)
('util/inspect.py:115', <cyfunction ReductionKernel.__init__ at 0x7f744a3f64d0>)
('util/inspect.py:115', <cyfunction Event.__init__ at 0x7f744afb6710>)
('util/inspect.py:115', <cyfunction Event.record at 0x7f744afb6950>)
('ext/autodoc.py:751', <type 'exceptions.TypeError'>, TypeError('<cyfunction Event.record at 0x7f744afb6950> is not a Python function',))
('util/inspect.py:115', <cyfunction Event.synchronize at 0x7f744afb6a10>)
('ext/autodoc.py:751', <type 'exceptions.TypeError'>, TypeError('<cyfunction Event.synchronize at 0x7f744afb6a10> is not a Python function',))
('util/inspect.py:115', <cyfunction Memory.__init__ at 0x7f744a3cab90>)
('util/inspect.py:115', <cyfunction Stream.__init__ at 0x7f744afb6ad0>)
('util/inspect.py:115', <cyfunction Stream.__enter__ at 0x7f744afb6c50>)
('ext/autodoc.py:751', <type 'exceptions.TypeError'>, TypeError('<cyfunction Stream.__enter__ at 0x7f744afb6c50> is not a Python function',))
('util/inspect.py:115', <cyfunction Stream.__exit__ at 0x7f744afb6d10>)
('ext/autodoc.py:751', <type 'exceptions.TypeError'>, TypeError('<cyfunction Stream.__exit__ at 0x7f744afb6d10> is not a Python function',))
('util/inspect.py:115', <cyfunction Stream.add_callback at 0x7f744a3ca050>)
('ext/autodoc.py:751', <type 'exceptions.TypeError'>, TypeError('<cyfunction Stream.add_callback at 0x7f744a3ca050> is not a Python function',))
('util/inspect.py:115', <cyfunction Stream.record at 0x7f744a3ca110>)
('ext/autodoc.py:751', <type 'exceptions.TypeError'>, TypeError('<cyfunction Stream.record at 0x7f744a3ca110> is not a Python function',))
('util/inspect.py:115', <cyfunction Stream.synchronize at 0x7f744afb6f50>)
('ext/autodoc.py:751', <type 'exceptions.TypeError'>, TypeError('<cyfunction Stream.synchronize at 0x7f744afb6f50> is not a Python function',))
('util/inspect.py:115', <cyfunction Stream.use at 0x7f744afb6dd0>)
('ext/autodoc.py:751', <type 'exceptions.TypeError'>, TypeError('<cyfunction Stream.use at 0x7f744afb6dd0> is not a Python function',))
('util/inspect.py:115', <cyfunction Stream.wait_event at 0x7f744a3ca1d0>)
('ext/autodoc.py:751', <type 'exceptions.TypeError'>, TypeError('<cyfunction Stream.wait_event at 0x7f744a3ca1d0> is not a Python function',))
('util/inspect.py:115', <cyfunction ufunc.__init__ at 0x7f744a3e1b90>)
('util/inspect.py:115', <cyfunction ReductionKernel.__init__ at 0x7f744a3f64d0>)
('util/inspect.py:115', <cyfunction ufunc.__init__ at 0x7f744a3e1b90>)

@sonots
Copy link
Contributor Author

sonots commented Nov 6, 2017

jenkins, test this please

@okuta
Copy link
Member

okuta commented Nov 6, 2017

LGTM!

@okuta okuta merged commit 6d4bb04 into cupy:master Nov 6, 2017
@sonots sonots deleted the stream_pool branch November 7, 2017 03:37
@niboshi niboshi added the no-compat Disrupts backward compatibility label Jan 11, 2018
@sunbearc22
Copy link

Can you provide a simple example of how to use Nvidia Visual profiler nvvp to profile a python-cupy script? Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cat:feature New features/APIs no-compat Disrupts backward compatibility
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Running kernels in CUDA stream
4 participants