[passes] Add ConvertMatmulToLinear pass #341
Conversation
Force-pushed fddc18a to 72a3617
```python
def get_compile_config(self):
    return CompileConfigV1(convert_lhs_const_mm_to_fc=True)
```
@glistening @seockho-kim Using this compile config will enable matmul op with lhs const node conversion.
@dayo09 Could you improve it to handle bmm, too?
```python
@tag.use_onert
class BmmTest(TestModuleBase):
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(2, 3, 4)

    def forward(self, rhs):
        out = self.weight @ rhs
        return out

    def get_example_inputs(self):
        return (torch.randn(2, 4, 5),), {}

    def get_compile_config(self):
        return CompileConfigV1(convert_lhs_const_mm_to_fc=True)
```
@seockho-kim The case above is not supported, because matmul-to-FC conversion can be done only if the weight is 2-dimensional. The Circle FullyConnected operation assumes its weight has rank 2.
I'm sorry, I gave you the wrong example. I meant the bmm(batch=1) case.
Let me add it in the next PR!
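For reference, a bmm with batch size 1 can be rewritten as a plain rank-2 mm (and hence become FC-convertible) by squeezing the batch dimension. This is only a sketch of what such a bmm-to-mm lowering could do, with illustrative shapes, not the pass itself:

```python
import torch

# Hypothetical batch-1 case: a constant lhs weight with a leading batch dim of 1.
weight = torch.randn(1, 3, 4)
rhs = torch.randn(1, 4, 5)

bmm_out = torch.bmm(weight, rhs)

# Squeeze the batch dim, run a rank-2 matmul, then restore the batch dim.
mm_out = (weight.squeeze(0) @ rhs.squeeze(0)).unsqueeze(0)

assert torch.allclose(bmm_out, mm_out)
```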
* Linear has better quantization accuracy (NPU backend), due to the ONE compiler's quantization policy:
  * FullyConnected (= Linear) uses per-channel quantization for the weight and per-tensor for the input.
  * BatchMatmul (= matmul) uses per-tensor quantization for both lhs and rhs.
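To illustrate why per-channel weight quantization tends to be more accurate, here is a small NumPy sketch (not TICO/ONE code; the symmetric int8-style scheme and the toy values are made up) comparing per-tensor and per-channel quantization of a weight whose rows differ widely in magnitude:

```python
import numpy as np

# Toy weight: row 0 has small values, row 1 has large values.
W = np.array([[0.1, 0.2],
              [10.0, 20.0]], dtype=np.float32)

def quantize(w, scale):
    q = np.clip(np.round(w / scale), -127, 127)
    return q * scale  # dequantized reconstruction

# Per-tensor: one scale for the whole tensor, dominated by the largest row.
per_tensor = quantize(W, np.abs(W).max() / 127.0)

# Per-channel: one scale per output channel (row).
per_channel = quantize(W, np.abs(W).max(axis=1, keepdims=True) / 127.0)

pt_err = np.abs(W - per_tensor).mean()
pc_err = np.abs(W - per_channel).mean()
assert pc_err < pt_err  # the small-valued row survives with far less error
```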
FYI, a new generation of the NPU would support cwq (channel-wise quantization) for matmul.
@jinevening Do you mean 3rd generation?
Yes.
```python
inputs = [input, other]
outputs = [node]

if not is_const(other) and prior_latency:
```
prior_latency is not used anymore?
@jinevening Yes, the old feature is basically part of the new one.
```text
# BEFORE: prior_latency == False (default)
# AFTER:  (default)
if rhs is const: conversion ON
else:            conversion OFF

# BEFORE: prior_latency == True
# AFTER:  convert_rhs_const_mm_to_fc == False
always: conversion OFF
```
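The rule above can be sketched as a small predicate (hypothetical names; the real pass logic lives in TICO):

```python
def should_convert_to_fc(rhs_is_const: bool,
                         convert_rhs_const_mm_to_fc: bool = True) -> bool:
    """Hypothetical sketch: convert matmul to FullyConnected only when
    the rhs is constant and the config flag (default True) allows it."""
    return rhs_is_const and convert_rhs_const_mm_to_fc

# Default config: conversion ON iff rhs is const.
assert should_convert_to_fc(rhs_is_const=True) is True
assert should_convert_to_fc(rhs_is_const=False) is False
# Flag disabled: always OFF (the old prior_latency=True behavior).
assert should_convert_to_fc(True, convert_rhs_const_mm_to_fc=False) is False
```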
I see. Then it would be possible to remove that arg and related codes.
Removed ;-D
```python
    return fc_node


class ConvertLhsConstMatmulToLinear(Converter):
```
Why should it distinguish lhs and rhs? Do left and right matter when converting?
@mhs4670go onert doesn't run lhs-const matmul. So this PR converts matmul to FullyConnected for onert, conditionally via the config.
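For context, the lhs-const case can only reach a Linear/FullyConnected form through extra transposes, since W @ x == (xᵀ @ Wᵀ)ᵀ. A minimal PyTorch sketch of that identity (illustrative, not the pass implementation):

```python
import torch
import torch.nn.functional as F

W = torch.randn(3, 4)   # constant lhs
x = torch.randn(4, 5)   # runtime rhs

direct = W @ x

# F.linear(input, weight) computes input @ weight.T, so
# W @ x == (x.T @ W.T).T == F.linear(x.T, W).T.
# The two transposes here act on runtime tensors, which is
# where a latency cost can appear.
via_linear = F.linear(x.t(), W).t()

assert torch.allclose(direct, via_linear, atol=1e-5)
```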
@glistening I saw this comment. Is this pass still needed?

@glistening @seockho-kim What do you think: is this pass still needed? If it is, I plan to add a pass for lowering bmm to mm in another PR. Please share your opinions.
Force-pushed 27bde5b to 9027726
@dayo09 Sure. I need your PR :). The maintainer of …

@jinevening @seockho-kim @mhs4670go PTAL :-D
| """ """ | ||
|
|
||
| def __init__(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| """ """ | |
| def __init__(self): | |
| def __init__(self): |
I think it would be good to describe what error is expected. NNFW_STATUS_ERROR is a bit ambiguous.
That is how onert throws. It should match.
Ah, I think using docstring or comments is also enough.
jinevening left a comment
LGTM
```python
convert_lhs_const_mm_to_fc: bool = False
convert_rhs_const_mm_to_fc: bool = True
```
On second thought, a single convert_const_mm_to_fc could be a simpler choice. Do you have any reason for choosing this design?
rhs_const_mm_to_fc doesn't have a trade-off because the transpose is foldable into the const, but lhs_const_mm_to_fc involves a potential latency trade-off. Therefore, the user needs to decide on each case separately.
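The asymmetry can be seen in a short PyTorch sketch (illustrative, not the pass code): for a const rhs the transpose lands on the constant weight and can be folded offline, while the lhs-const case needs transposes on runtime tensors:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4)    # runtime input
W = torch.randn(4, 3)    # constant rhs

# rhs-const: x @ W == F.linear(x, W.T); W.T is itself a constant,
# so the transpose folds away at compile time -> no runtime cost.
folded_weight = W.t().contiguous()  # precomputed once, offline
assert torch.allclose(x @ W, F.linear(x, folded_weight), atol=1e-5)

# lhs-const: W2 @ x2 needs runtime transposes around the Linear.
W2 = torch.randn(3, 4)   # constant lhs
x2 = torch.randn(4, 5)   # runtime input
assert torch.allclose(W2 @ x2, F.linear(x2.t(), W2).t(), atol=1e-5)
```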
If I understand correctly, TICO will generate … To remove the redundant …

We've decided to delegate graph optimization to one-optimize in order to avoid code duplication. Is it hard to use one-optimize?
mhs4670go left a comment
Yes, I agree that …

(I would like to check that it works. However, as usual, I cannot access GitHub. I was just curious why …)
Hmm. I succeeded to …

@seockho-kim Does this PR solve the same issue in gemma3?
@dayo09 Conflicts should be resolved.

```python
config = tico.CompileConfigV1()
config.convert_lhs_const_mm_to_fc = True
circle_model = tico.convert(torch_module, example_inputs, config=config)
```
Let's add convert matmul to linear pass.

This commit...
- refactors mm serialization logic and makes a convert_matmul_to_linear pass
- introduces new CompileConfig attributes convert_lhs/rhs_const_mm_to_fc

TICO-DCO-1.0-Signed-off-by: Dayoung Lee <dayoung.lee@samsung.com>
Co-authored-by: Hyukjin Jeong <hj1.jeong@samsung.com>
Force-pushed 9027726 to 4588a1b
@glistening Sorry for the ambiguity in my comment. I mean, this PR doesn't support bmm to mm YET; it needs a further PR.
@jinevening @seockho-kim @mhs4670go It's rebased. PTAL again 😅
This PR does not solve the issue in gemma3, because bmm is not changed.
seockho-kim left a comment
LGTM
On second thought, just adding …

@jinevening @mhs4670go Do you think conversion of … ?
I think it should be optional because it's not a kind of circle legalization.
+1 for the latter
For #339