Constrained scan update #44
base: develop

Conversation
Sorry, I think this needs another rebase. LMK if you need help with it. Also, can you fix the tests? It should be as simple as changing the scan access. We should probably also take a look at the conventional engines before merge; it shouldn't be too difficult.
```python
# cast = to_real_dtype(sim.object.data.dtype)
xp = get_array_module(sim.scan.data)
update = xp.zeros_like(sim.scan.data, dtype=sim.scan.data.dtype)
for kind, weight in self.constraints.items():
```
Is this deterministic? It seems like it could apply the updates in arbitrary order; we may want to add `sorted()` if it matters.
This should be deterministic, since the updates are summed and applied only after each `kind * weight` term is calculated from the common unconstrained update and the previous scan state. Example:

```python
update += scan_affine(sim.scan.data, state.previous) * weight
```
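A minimal sketch of why summation makes the result order-independent (the constraint functions here are hypothetical stand-ins for `scan_affine` / `scan_line`, and the scan is reduced to a scalar for illustration):

```python
# Hypothetical constraint functions standing in for scan_affine / scan_line;
# each maps the current scan value to an update term.
def f_a(x):
    return 0.5 * x

def f_b(x):
    return x - 1.0

def apply_constraints(x, constraints):
    # Updates are accumulated into one total and applied once at the end,
    # so the iteration order of the constraints dict cannot change the result.
    update = 0.0
    for fn, weight in constraints.items():
        update += fn(x) * weight
    return x + update

# Same constraints in either insertion order give the same result:
assert apply_constraints(3.0, {f_a: 0.3, f_b: 0.7}) == \
       apply_constraints(3.0, {f_b: 0.7, f_a: 0.3})
```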
```python
self.constraints[kind] = getattr(props, kind)
self.total_weight = sum(self.constraints.values())
# self.weight: t.Optional[float]
# self.type: t.Optional[str]
logger.info(
    f"Initialized scan constraint with kinds {list(self.constraints.keys())} "
    f"and weights {list(self.constraints.values())} "
    f"with total weight {self.total_weight:.4f}"
)
```
What is the meaning of the weights here? Would it make more sense to add a weight for each constraint as a relaxation parameter?
My idea for the weights was to make the applied update at each step configurable as a weighted sum of each constraint type, where each constraint is calculated from the unconstrained update.
Ideally it would assert that `sum(weights) == 1`, but instead I just renormalize by the total sum of all the values specified in the YAML.
Fixed, so the default update weight is now `1 - sum(constraints)`.
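A sketch of this weighting scheme as described (the names and config values are illustrative, not the actual phaser config): each constraint gets a weight from the YAML, and the unconstrained ("default") update receives whatever weight remains, so the total is always 1.

```python
# Illustrative constraint weights, as they might come from the YAML config.
constraints = {'affine': 0.2, 'line': 0.3}

# The default (unconstrained) update takes the remaining weight,
# so no explicit renormalization or assertion is needed.
default_weight = 1.0 - sum(constraints.values())
weights = {'default': default_weight, **constraints}

assert abs(sum(weights.values()) - 1.0) < 1e-12
```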
```python
if kind == 'affine':
    update += scan_affine(sim.scan.data, state.previous) * weight
    # sims.object.data = ## affine deform object
if kind == 'line' and state.row_bins is not None:
    update += scan_line(sim.scan.data, state.previous, state.row_bins) * weight
if kind == 'hpf':
    pass
if kind == 'lpf':
    pass
if kind == 'default':
    update += scan_default(sim.scan.data, state.previous) * weight
```
This might be better as a dictionary of update functions. It could also be a `match` statement, but I don't remember what our minimum supported version is.
It looks like `match`-`case` was introduced in Python 3.10, and the specification here is `>=3.10`.
I haven't checked whether 3.10 plus all dependencies actually works, though. Either way, I think it makes sense to change to `match`-`case`.
Updated to a `match` statement.
```python
## double check that if position update is off (scan == prev_step), this doesn't break anything
# @partial(jit, donate_argnames=('pos',), cupy_fuse=True)
def scan_default(
```
These should be private functions (e.g. `_scan_default()`).
```python
left = xp.matmul(pos_prev.T, disp_update)
right = xp.matmul(pos_prev.T, pos_prev)
A = xp.matmul(xp.linalg.inv(right), left)
constraint = xp.matmul(pos_prev, A)
# remove the middle shift, keep the middle unchanged
center_ones = xp.ones((1, 1), pos.dtype)
# center[0, 0:2] = xp.average(pos, axis=0)
center = xp.concatenate([xp.average(pos, axis=0, keepdims=True), center_ones], axis=1, dtype=pos.dtype)
center_shift = xp.matmul(center, A)
```
`xp.matmul(x, y)` should be replaceable with `x @ y`.
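For example, the least-squares solve above could be rewritten with the `@` operator; this sketch uses NumPy directly (random stand-ins for `pos_prev` and `disp_update`), but the same operator works for CuPy arrays obtained through `get_array_module`:

```python
import numpy as np

rng = np.random.default_rng(0)
pos_prev = rng.standard_normal((5, 3))     # stand-in for the previous positions
disp_update = rng.standard_normal((5, 3))  # stand-in for the displacement update

# The @ operator is equivalent to np.matmul for 2-D arrays:
left = pos_prev.T @ disp_update
right = pos_prev.T @ pos_prev
A = np.linalg.inv(right) @ left

A_matmul = np.matmul(
    np.linalg.inv(np.matmul(pos_prev.T, pos_prev)),
    np.matmul(pos_prev.T, disp_update),
)
assert np.allclose(A, A_matmul)
```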
```diff
  cost = xp.sum(xp.abs(sim.object.data - 1.0))
- cost_scale = xp.array(group.shape[-1] / prod(sim.scan.shape[:-1]), dtype=cost.dtype)
+ cost_scale = xp.array(group.shape[-1] / prod(sim.scan.data.shape[:-1]), dtype=cost.dtype)
```
Maybe unnecessary for this PR, but we could probably make `n_pos()` a method of `sim.scan` to avoid this repetition.
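A sketch of the suggested helper (the `Scan` class and its layout here are hypothetical, not the actual phaser scan object): a method returning the number of probe positions, so call sites stop repeating `prod(scan.data.shape[:-1])`.

```python
from math import prod

class Scan:
    """Hypothetical scan container; the last axis holds (y, x) coordinates."""

    def __init__(self, shape):
        self.shape = shape  # e.g. (rows, cols, 2)

    def n_pos(self):
        # All leading axes index positions; only the final coordinate
        # axis is excluded from the count.
        return prod(self.shape[:-1])

assert Scan((64, 64, 2)).n_pos() == 4096
```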
phaser/execute.py (outdated)
```python
## FIXME: the scan normalization here happens before dropnans and scan data flattening,
## but may alter shape and therefore rows/cols? Why is this needed?
def _normalize_scan_shape(
    patterns: Patterns, state: ReconsState
) -> t.Tuple[Patterns, ReconsState]:
```
This basically exists for scans and patterns from heterogeneous sources, i.e. one from a previous state. It's a bit of a hack, but it should be possible to adapt.
I'll update and remove all my comments; they were just notes for myself.
```python
## FIXME: output to Tuple? importance of array number types
```

```python
@t.overload
```
I don't love this API, but I'm not sure what would be better. Maybe output `ScanState` directly? Or, maybe better, keep `make_raster_scan` clean and include the metadata in the hook only.
I'm in favor of outputting `ScanState` directly.
The grid would need to be re-calculated if not included here, but maybe that would be fine.
Force-pushed from 77911c3 to ef9f35e.
breaking with all previous h5 files, since they do not contain scan object
Force-pushed from ef9f35e to 3bca912.
This has been rebased on the latest probe aberration merge. I still need to update and test the conventional solvers.
Example: affine-only updates after 800 iterations for an experiment:
