Implement strum(flatten) for EnumIter#425
Implement strum(flatten) for EnumIter#425juliancoffee wants to merge 1 commit intoPeternator7:masterfrom
Conversation
ff6eab2 to
7b81796
Compare
|
This is the manual implementation of what these macros generate. It has some #[derive(Debug, Eq, PartialEq)]
enum Vibe {
Weak,
Average,
Strong,
}
impl Vibe {
fn iter() -> <Self as IntoIterator>::IntoIter {
let vibe = Vibe::Weak;
vibe.into_iter()
}
}
impl IntoIterator for Vibe {
type Item = Vibe;
type IntoIter = std::vec::IntoIter<Vibe>;
fn into_iter(self) -> Self::IntoIter {
vec![Vibe::Weak, Vibe::Average, Vibe::Strong].into_iter()
}
}
const SHADE_NUM: usize = 5;
#[derive(Debug, Eq, PartialEq)]
enum Shade {
Light,
Med1(Vibe),
Med2(Vibe),
Med3(Vibe),
Dark,
}
impl Shade {
fn iter() -> ShadeIter {
ShadeIter {
idx: 0,
med1_iter: Some(Vibe::iter()),
med2_iter: Some(Vibe::iter()),
med3_iter: Some(Vibe::iter()),
back_idx: 0,
}
}
}
impl Shade {
fn simple_iter() -> impl DoubleEndedIterator<Item = Shade> {
vec![Shade::Light]
.into_iter()
.chain(Vibe::iter().map(Shade::Med1))
.chain(Vibe::iter().map(Shade::Med2))
.chain(Vibe::iter().map(Shade::Med3))
.chain(vec![Shade::Dark])
}
}
struct ShadeIter {
idx: usize,
med1_iter: Option<<Vibe as IntoIterator>::IntoIter>,
med2_iter: Option<<Vibe as IntoIterator>::IntoIter>,
med3_iter: Option<<Vibe as IntoIterator>::IntoIter>,
back_idx: usize,
}
#[derive(Debug)]
enum Res {
Done(Shade),
DoneStep(Shade),
EndStep,
End,
}
impl ShadeIter {
fn nested_get(
nested_iter: &mut Option<<Vibe as IntoIterator>::IntoIter>,
wrap: fn(<Vibe as IntoIterator>::Item) -> Shade,
forward: bool,
) -> Res {
let next_inner = if forward {
nested_iter.as_mut().and_then(|t| t.next())
} else {
nested_iter.as_mut().and_then(|t| t.next_back())
};
if let Some(it) = next_inner {
Res::DoneStep(wrap(it))
} else {
nested_iter.take();
Res::EndStep
}
}
fn get(&mut self, idx: usize, forward: bool) -> Res {
let res = match dbg!(idx) {
0 => Res::Done(Shade::Light),
1 => Self::nested_get(&mut self.med1_iter, Shade::Med1, forward),
2 => Self::nested_get(&mut self.med2_iter, Shade::Med2, forward),
3 => Self::nested_get(&mut self.med3_iter, Shade::Med3, forward),
4 => Res::Done(Shade::Dark),
_ => Res::End,
};
dbg!(res)
}
}
impl Iterator for ShadeIter {
type Item = Shade;
fn next(&mut self) -> Option<Self::Item> {
self.nth(0)
}
fn nth(&mut self, n: usize) -> Option<Self::Item> {
if self.back_idx + self.idx >= SHADE_NUM {
return None;
}
match ShadeIter::get(self, dbg!(self.idx) + dbg!(n), true) {
Res::Done(x) => {
// move to requested, and past it
self.idx += n + 1;
Some(x)
}
Res::DoneStep(x) => {
// move to requested, but not past it
self.idx += n;
Some(x)
}
Res::EndStep => {
// ok, this one failed, move past it and request again
self.idx += 1;
let res = self.nth(0);
res
}
Res::End => None,
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
/*
let min = if self.idx + self.back_idx >= SHADE_NUM {
0
} else {
SHADE_NUM - self.idx - self.back_idx
};
*/
let med1_size = self.med1_iter.as_ref().map_or(0, |t| {
t.len()
});
let med2_size = self.med2_iter.as_ref().map_or(0, |t| {
t.len()
});
let med3_size = self.med3_iter.as_ref().map_or(0, |t| {
t.len()
});
let t = SHADE_NUM
+ dbg!(med1_size) - self.med1_iter.as_ref().map_or(0, |_| 1)
+ dbg!(med2_size) - self.med2_iter.as_ref().map_or(0, |_| 1)
+ dbg!(med3_size) - self.med3_iter.as_ref().map_or(0, |_| 1)
- dbg!(self.idx)
- dbg!(self.back_idx);
(t, Some(t))
}
}
impl ShadeIter {
fn nth_back(&mut self, back_n: usize) -> Option<Shade> {
if self.back_idx + self.idx >= SHADE_NUM {
return None;
}
let res = match ShadeIter::get(
self,
SHADE_NUM - dbg!(self.back_idx) - back_n - 1,
false,
) {
Res::Done(x) => {
// move to requested, and past it
self.back_idx += 1;
Some(x)
}
Res::DoneStep(x) => {
// move to requested, but not past it
Some(x)
}
Res::EndStep => {
// ok, this one failed, try the next one
self.back_idx += 1;
self.nth_back(0)
}
Res::End => None,
};
res
}
}
impl DoubleEndedIterator for ShadeIter {
fn next_back(&mut self) -> Option<Self::Item> {
self.nth_back(0)
}
}
impl ExactSizeIterator for ShadeIter {
fn len(&self) -> usize {
self.size_hint().0
}
}
fn main() {
println!("Hello, world!");
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn flatten() {
let result = Shade::iter().collect::<Vec<_>>();
let expected = vec![
Shade::Light,
Shade::Med1(Vibe::Weak),
Shade::Med1(Vibe::Average),
Shade::Med1(Vibe::Strong),
Shade::Med2(Vibe::Weak),
Shade::Med2(Vibe::Average),
Shade::Med2(Vibe::Strong),
Shade::Med3(Vibe::Weak),
Shade::Med3(Vibe::Average),
Shade::Med3(Vibe::Strong),
Shade::Dark,
];
assert_eq!(result, expected);
}
#[test]
fn flatten_back() {
let result = Shade::iter().rev().collect::<Vec<_>>();
let expected = vec![
Shade::Dark,
Shade::Med3(Vibe::Strong),
Shade::Med3(Vibe::Average),
Shade::Med3(Vibe::Weak),
Shade::Med2(Vibe::Strong),
Shade::Med2(Vibe::Average),
Shade::Med2(Vibe::Weak),
Shade::Med1(Vibe::Strong),
Shade::Med1(Vibe::Average),
Shade::Med1(Vibe::Weak),
Shade::Light,
];
assert_eq!(result, expected);
}
#[test]
fn iter_mixed_next_and_next_back() {
let mut iter = Shade::iter();
assert_eq!(iter.next(), Some(Shade::Light));
assert_eq!(iter.next_back(), Some(Shade::Dark));
assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Weak)));
assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Strong)));
assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Average)));
assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Average)));
assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Strong)));
assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Weak)));
assert_eq!(iter.next(), Some(Shade::Med2(Vibe::Weak)));
assert_eq!(iter.next_back(), Some(Shade::Med2(Vibe::Strong)));
assert_eq!(iter.next(), Some(Shade::Med2(Vibe::Average)));
assert_eq!(iter.next_back(), None);
}
#[test]
fn iter_quickheck() {
use rand::Rng;
let mut rng = rand::rng();
for _ in 0..1000 {
let mut iter = Shade::iter();
let mut simple_iter = Shade::simple_iter();
let mut results = vec![];
let mut expected = vec![];
for _ in 0..500 {
if rng.random_bool(0.5) {
results.push(iter.next());
expected.push(simple_iter.next());
} else {
results.push(iter.next_back());
expected.push(simple_iter.next_back());
}
}
assert_eq!(results, expected);
}
}
#[test]
fn iter_quickheck_sizehint() {
use rand::Rng;
let mut rng = rand::rng();
for _ in 0..1000 {
let mut iter = Shade::iter();
let mut simple_iter = Shade::simple_iter();
assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
for _ in 0..500 {
if rng.random_bool(0.5) {
dbg!("next");
_ = iter.next();
_ = simple_iter.next();
assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
} else {
dbg!("next_back");
_ = iter.next_back();
_ = simple_iter.next_back();
assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
}
}
}
}
#[test]
fn iter_quickheck_len() {
use rand::Rng;
let mut rng = rand::rng();
for _ in 0..1000 {
let mut iter = Shade::iter();
const MAX: usize = 11;
assert_eq!(dbg!(iter.len()), MAX);
for i in 1..=MAX {
if rng.random_bool(0.5) {
dbg!("next");
_ = iter.next();
} else {
dbg!("next_back");
_ = iter.next_back();
}
assert_eq!(dbg!(iter.len()), MAX - i);
}
}
}
}Open to your comments 🙌 |
There was a problem hiding this comment.
Hi, thx for this PR I also need this but didn't get the time to look into it much, you're saving me big time !
One concern I have is that the generated code isn't compatible with no_std anymore due to the vec![].
Quick testing shows that a simple array also does the trick
Updated example, nothing much changes, every vec![] is replaced by [] and
type IntoIter = std::vec::IntoIter<Vibe>; becomes type IntoIter = <[Self; 3] as core::iter::IntoIterator>::IntoIter; which could get tricky, maybe generate an associated constant containing the number of variants (<[Self; 4 + 3] as core::iter::IntoIterator>::IntoIter works) ?
#![no_std]
#[derive(Debug, Eq, PartialEq)]
enum Vibe {
Weak,
Average,
Strong,
}
impl Vibe {
fn iter() -> <Self as IntoIterator>::IntoIter {
let vibe = Vibe::Weak;
vibe.into_iter()
}
}
impl IntoIterator for Vibe {
type Item = Vibe;
type IntoIter = <[Self; 3] as core::iter::IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
[Vibe::Weak, Vibe::Average, Vibe::Strong].into_iter()
}
}
const SHADE_NUM: usize = 5;
#[derive(Debug, Eq, PartialEq)]
enum Shade {
Light,
Med1(Vibe),
Med2(Vibe),
Med3(Vibe),
Dark,
}
impl Shade {
fn iter() -> ShadeIter {
ShadeIter {
idx: 0,
med1_iter: Some(Vibe::iter()),
med2_iter: Some(Vibe::iter()),
med3_iter: Some(Vibe::iter()),
back_idx: 0,
}
}
}
impl Shade {
fn simple_iter() -> impl DoubleEndedIterator<Item = Shade> {
[Shade::Light]
.into_iter()
.chain(Vibe::iter().map(Shade::Med1))
.chain(Vibe::iter().map(Shade::Med2))
.chain(Vibe::iter().map(Shade::Med3))
.chain([Shade::Dark])
}
}
struct ShadeIter {
idx: usize,
med1_iter: Option<<Vibe as IntoIterator>::IntoIter>,
med2_iter: Option<<Vibe as IntoIterator>::IntoIter>,
med3_iter: Option<<Vibe as IntoIterator>::IntoIter>,
back_idx: usize,
}
#[derive(Debug)]
enum Res {
Done(Shade),
DoneStep(Shade),
EndStep,
End,
}
impl ShadeIter {
fn nested_get(
nested_iter: &mut Option<<Vibe as IntoIterator>::IntoIter>,
wrap: fn(<Vibe as IntoIterator>::Item) -> Shade,
forward: bool,
) -> Res {
let next_inner = if forward {
nested_iter.as_mut().and_then(|t| t.next())
} else {
nested_iter.as_mut().and_then(|t| t.next_back())
};
if let Some(it) = next_inner {
Res::DoneStep(wrap(it))
} else {
nested_iter.take();
Res::EndStep
}
}
fn get(&mut self, idx: usize, forward: bool) -> Res {
match idx {
0 => Res::Done(Shade::Light),
1 => Self::nested_get(&mut self.med1_iter, Shade::Med1, forward),
2 => Self::nested_get(&mut self.med2_iter, Shade::Med2, forward),
3 => Self::nested_get(&mut self.med3_iter, Shade::Med3, forward),
4 => Res::Done(Shade::Dark),
_ => Res::End,
}
}
}
impl Iterator for ShadeIter {
type Item = Shade;
fn next(&mut self) -> Option<Self::Item> {
self.nth(0)
}
fn nth(&mut self, n: usize) -> Option<Self::Item> {
if self.back_idx + self.idx >= SHADE_NUM {
return None;
}
match ShadeIter::get(self, self.idx + n, true) {
Res::Done(x) => {
// move to requested, and past it
self.idx += n + 1;
Some(x)
}
Res::DoneStep(x) => {
// move to requested, but not past it
self.idx += n;
Some(x)
}
Res::EndStep => {
// ok, this one failed, move past it and request again
self.idx += 1;
let res = self.nth(0);
res
}
Res::End => None,
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
/*
let min = if self.idx + self.back_idx >= SHADE_NUM {
0
} else {
SHADE_NUM - self.idx - self.back_idx
};
*/
let med1_size = self.med1_iter.as_ref().map_or(0, |t| {
t.len()
});
let med2_size = self.med2_iter.as_ref().map_or(0, |t| {
t.len()
});
let med3_size = self.med3_iter.as_ref().map_or(0, |t| {
t.len()
});
let t = SHADE_NUM
+ (med1_size) - self.med1_iter.as_ref().map_or(0, |_| 1)
+ (med2_size) - self.med2_iter.as_ref().map_or(0, |_| 1)
+ (med3_size) - self.med3_iter.as_ref().map_or(0, |_| 1)
- (self.idx)
- (self.back_idx);
(t, Some(t))
}
}
impl ShadeIter {
fn nth_back(&mut self, back_n: usize) -> Option<Shade> {
if self.back_idx + self.idx >= SHADE_NUM {
return None;
}
let res = match ShadeIter::get(
self,
SHADE_NUM - self.back_idx - back_n - 1,
false,
) {
Res::Done(x) => {
// move to requested, and past it
self.back_idx += 1;
Some(x)
}
Res::DoneStep(x) => {
// move to requested, but not past it
Some(x)
}
Res::EndStep => {
// ok, this one failed, try the next one
self.back_idx += 1;
self.nth_back(0)
}
Res::End => None,
};
res
}
}
impl DoubleEndedIterator for ShadeIter {
fn next_back(&mut self) -> Option<Self::Item> {
self.nth_back(0)
}
}
impl ExactSizeIterator for ShadeIter {
fn len(&self) -> usize {
self.size_hint().0
}
}
const fn main() {
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn flatten() {
let result = Shade::iter().collect::<Vec<_>>();
let expected = vec![
Shade::Light,
Shade::Med1(Vibe::Weak),
Shade::Med1(Vibe::Average),
Shade::Med1(Vibe::Strong),
Shade::Med2(Vibe::Weak),
Shade::Med2(Vibe::Average),
Shade::Med2(Vibe::Strong),
Shade::Med3(Vibe::Weak),
Shade::Med3(Vibe::Average),
Shade::Med3(Vibe::Strong),
Shade::Dark,
];
assert_eq!(result, expected);
}
#[test]
fn flatten_back() {
let result = Shade::iter().rev().collect::<Vec<_>>();
let expected = vec![
Shade::Dark,
Shade::Med3(Vibe::Strong),
Shade::Med3(Vibe::Average),
Shade::Med3(Vibe::Weak),
Shade::Med2(Vibe::Strong),
Shade::Med2(Vibe::Average),
Shade::Med2(Vibe::Weak),
Shade::Med1(Vibe::Strong),
Shade::Med1(Vibe::Average),
Shade::Med1(Vibe::Weak),
Shade::Light,
];
assert_eq!(result, expected);
}
#[test]
fn iter_mixed_next_and_next_back() {
let mut iter = Shade::iter();
assert_eq!(iter.next(), Some(Shade::Light));
assert_eq!(iter.next_back(), Some(Shade::Dark));
assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Weak)));
assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Strong)));
assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Average)));
assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Average)));
assert_eq!(iter.next(), Some(Shade::Med1(Vibe::Strong)));
assert_eq!(iter.next_back(), Some(Shade::Med3(Vibe::Weak)));
assert_eq!(iter.next(), Some(Shade::Med2(Vibe::Weak)));
assert_eq!(iter.next_back(), Some(Shade::Med2(Vibe::Strong)));
assert_eq!(iter.next(), Some(Shade::Med2(Vibe::Average)));
assert_eq!(iter.next_back(), None);
}
#[test]
fn iter_quickheck() {
use rand::Rng;
let mut rng = rand::rng();
for _ in 0..1000 {
let mut iter = Shade::iter();
let mut simple_iter = Shade::simple_iter();
let mut results = vec![];
let mut expected = vec![];
for _ in 0..500 {
if rng.random_bool(0.5) {
results.push(iter.next());
expected.push(simple_iter.next());
} else {
results.push(iter.next_back());
expected.push(simple_iter.next_back());
}
}
assert_eq!(results, expected);
}
}
#[test]
fn iter_quickheck_sizehint() {
use rand::Rng;
let mut rng = rand::rng();
for _ in 0..1000 {
let mut iter = Shade::iter();
let mut simple_iter = Shade::simple_iter();
assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
for _ in 0..500 {
if rng.random_bool(0.5) {
dbg!("next");
_ = iter.next();
_ = simple_iter.next();
assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
} else {
dbg!("next_back");
_ = iter.next_back();
_ = simple_iter.next_back();
assert_eq!(dbg!(iter.size_hint()), simple_iter.size_hint());
}
}
}
}
#[test]
fn iter_quickheck_len() {
use rand::Rng;
let mut rng = rand::rng();
for _ in 0..1000 {
let mut iter = Shade::iter();
const MAX: usize = 11;
assert_eq!(dbg!(iter.len()), MAX);
for i in 1..=MAX {
if rng.random_bool(0.5) {
dbg!("next");
_ = iter.next();
} else {
dbg!("next_back");
_ = iter.next_back();
}
assert_eq!(dbg!(iter.len()), MAX - i);
}
}
}
}|
@vic1707 Vibe::iter() was added because I needed a nested iterator, and yeah, I didn't care much about its implementation, because it wouldn't be present in "real" code. Shade::simple_iter() is there so that I have something to compare results to without writing too many tests, so it wouldn't be present in generated code as well. Thanks for noting that, though. I guess the drawback of |
|
Sorry I for that misunderstanding on my part, good job, can't wait to see it land if the devs are ok 👍 |
| custom_keyword!(default_with); | ||
| custom_keyword!(props); | ||
| custom_keyword!(ascii_case_insensitive); | ||
| custom_keyword!(flatten); |
There was a problem hiding this comment.
honestly, the biggest concern I have here is how should #[strum(flatten)] interact with other derives
There was a problem hiding this comment.
basically this thing
Fixes #424
As I said, it's possible if slightly complex. I'm not an expert in writing iterators, though, so maybe it's possible to cut some rough edges; I just tried to make it correct.
I tried to produce a slim diff, but DoubleEndedIterator implementation went into pieces.
Also, you can see in tests,
Color::simple_iter()gives much simpler implementation, but maybe a bit slower to run and/or compile? I didn't bench it.UPD: I think I know how to simplify this a little (without going through implementation I pointed out above), so if you're interested I'll try to refactor it a bit