Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule][M3b] Argument Info #9059

Merged
merged 1 commit into from
Sep 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions include/tvm/meta_schedule/arg_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_META_SCHEDULE_ARG_INFO_H_
#define TVM_META_SCHEDULE_ARG_INFO_H_

#include <tvm/node/node.h>
#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/tir/function.h>

namespace tvm {
namespace meta_schedule {

/*! \brief The argument information. */
class ArgInfoNode : public runtime::Object {
public:
static constexpr const char* _type_key = "meta_schedule.ArgInfo";
TVM_DECLARE_BASE_OBJECT_INFO(ArgInfoNode, runtime::Object);

public:
/*! \brief Default destructor. */
virtual ~ArgInfoNode() = default;
/*! \brief Converts the ArgInfo to its corresponding JSON representation. */
virtual ObjectRef AsJSON() const = 0;
};

/*!
* \brief Managed reference to ArgInfoNode
* \sa ArgInfoNode
*/
class ArgInfo : public runtime::ObjectRef {
public:
/*!
* \brief Parse the argument information from a JSON object.
* \param json_obj The json object to parse.
* \return The argument information parsed.
*/
TVM_DLL static ArgInfo FromJSON(const ObjectRef& json_obj);
/*!
* \brief Extract a list of the argument information from PrimFunc.
* \param func The PrimFunc to get argument information from.
* \return An array of the argument information derived.
*/
TVM_DLL static Array<ArgInfo, void> FromPrimFunc(const tir::PrimFunc& func);

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ArgInfo, runtime::ObjectRef, ArgInfoNode);

protected:
ArgInfo() = default;
};

/*! \brief The tensor argument information. */
class TensorInfoNode : public ArgInfoNode {
public:
/*! \brief The data type of the tensor. */
runtime::DataType dtype;
/*! \brief The shape of the tensor. */
runtime::ShapeTuple shape;
junrushao marked this conversation as resolved.
Show resolved Hide resolved

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("shape", &shape);
}

static constexpr const char* _type_key = "meta_schedule.TensorInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, ArgInfoNode);

public:
ObjectRef AsJSON() const;
};

/*!
* \brief Managed reference to TensorInfoNode
* \sa TensorInfoNode
*/
class TensorInfo : public ArgInfo {
public:
/*!
* \brief Constructor of TensorInfo.
* \param dtype The data type of the tensor argument.
* \param shape The shape tuple of the tensor argument.
*/
TVM_DLL explicit TensorInfo(runtime::DataType dtype, runtime::ShapeTuple shape);
/*!
* \brief Parse the argument information from a JSON object.
* \param json_obj The json object to parse.
* \return The argument information parsed.
*/
TVM_DLL static TensorInfo FromJSON(const ObjectRef& json_obj);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorInfo, ArgInfo, TensorInfoNode);
};

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_ARG_INFO_H_
10 changes: 9 additions & 1 deletion include/tvm/runtime/container/map.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <utility>

#include "./base.h"
#include "./optional.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -1344,7 +1345,14 @@ class Map : public ObjectRef {
iterator end() const { return iterator(GetMapNode()->end()); }
/*! \return find the key and returns the associated iterator */
iterator find(const K& key) const { return iterator(GetMapNode()->find(key)); }

/*! \return The value associated with the key, NullOpt if not found */
Optional<V> Get(const K& key) const {
MapNode::iterator iter = GetMapNode()->find(key);
if (iter == GetMapNode()->end()) {
return NullOptType{};
}
return DowncastNoCheck<V>(iter->second);
}
void erase(const K& key) { CopyOnWrite()->erase(key); }

/*!
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@
# under the License.
"""Package `tvm.meta_schedule`. The meta schedule infrastructure."""
from . import builder
from . import arg_info
from .tune_context import TuneContext
106 changes: 106 additions & 0 deletions python/tvm/meta_schedule/arg_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""The argument information"""
from typing import Any, List, Union

from tvm._ffi import register_object
from tvm.runtime import DataType, Object, ShapeTuple
from tvm.tir import PrimFunc

from . import _ffi_api
from .utils import _json_de_tvm


@register_object("meta_schedule.ArgInfo")
class ArgInfo(Object):
"""Argument information"""

def as_json(self) -> Any:
"""Converts the ArgInfo to its corresponding JSON representation."""
return _json_de_tvm(_ffi_api.ArgInfoAsJSON(self)) # type: ignore # pylint: disable=no-member

@staticmethod
def from_json(json_obj: Any) -> "ArgInfo":
"""Parse the argument information from a JSON object.
Parameters
----------
json_obj : Any
The json object to parse.
Returns
-------
parsed : ArgInfo
The argument information parsed.
"""
return _ffi_api.ArgInfoFromJSON(json_obj) # type: ignore # pylint: disable=no-member

@staticmethod
def from_prim_func(func: PrimFunc) -> List["ArgInfo"]:
"""Extract a list of the argument information from PrimFunc.
Parameters
----------
func : PrimFunc
The PrimFunc to get argument information from.
Returns
-------
extracted : List[ArgInfo]
An array of the argument information derived.
"""
return _ffi_api.ArgInfoFromPrimFunc(func) # type: ignore # pylint: disable=no-member


@register_object("meta_schedule.TensorInfo")
class TensorInfo(ArgInfo):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we support scalar argument?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes definitely! We can write a new subclass of ArgInfo for this :-)

"""Tensor argument information
Parameters
----------
dtype : DataType
The data type of the tensor.
shape : ShapeTuple
The shape of the tensor.
"""

dtype: DataType
shape: ShapeTuple

def __init__(
self,
dtype: DataType,
shape: Union[ShapeTuple, List[int]],
) -> None:
"""Constructor
Parameters
----------
dtype : DataType
The data type of the tensor.
shape : ShapeTuple
The shape of the tensor.
"""
if isinstance(shape, ShapeTuple):
shape_tuple = shape
else:
shape_tuple = ShapeTuple(shape)
self.__init_handle_by_constructor__(
_ffi_api.TensorInfo, # type: ignore # pylint: disable=no-member
dtype,
shape_tuple,
)
33 changes: 32 additions & 1 deletion python/tvm/meta_schedule/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
"""Utilities for meta schedule"""
import os
import shutil
from typing import Callable, Union
from typing import Any, Callable, Union

import psutil

from tvm._ffi import get_global_func, register_func
from tvm.error import TVMError
from tvm.ir import Array, Map
from tvm.runtime import String
from tvm.tir import FloatImm, IntImm


@register_func("meta_schedule.cpu_count")
Expand Down Expand Up @@ -95,3 +98,31 @@ def get_global_func_with_default_on_worker(
def remove_build_dir(artifact_path: str) -> None:
"""Clean up the build directory"""
shutil.rmtree(os.path.dirname(artifact_path))


def _json_de_tvm(obj: Any) -> Any:
"""Unpack a TVM nested container to a JSON object in python.
Parameters
----------
obj : Any
The TVM nested container to be unpacked.
Returns
-------
result : Any
The unpacked json object.
"""
if obj is None:
return None
if isinstance(obj, (int, float)):
return obj
if isinstance(obj, (IntImm, FloatImm)):
return obj.value
if isinstance(obj, (str, String)):
return str(obj)
if isinstance(obj, Array):
return [_json_de_tvm(i) for i in obj]
if isinstance(obj, Map):
return {_json_de_tvm(k): _json_de_tvm(v) for k, v in obj.items()}
raise TypeError("Not supported type: " + str(type(obj)))
2 changes: 1 addition & 1 deletion python/tvm/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@
from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, ext_dev
from .module import load_module, enabled, system_lib
from .container import String
from .container import String, ShapeTuple
from .params import save_param_dict, load_param_dict
Loading