Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions teeplot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Top-level package for teeplot."""

__author__ = """Matthew Andres Moreno"""
__email__ = 'm.more500@gmail.com'
__version__ = '1.2.0'
__email__ = "m.more500@gmail.com"
__version__ = "1.2.0"
34 changes: 33 additions & 1 deletion teeplot/teeplot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import abc, Counter
from contextlib import contextmanager
import copy
import functools
import os
import pathlib
import typing
Expand Down Expand Up @@ -359,7 +360,7 @@ def save_callback():


@contextmanager
def teed(*args: list, **kwargs: dict):
def teed(*args, **kwargs):
"""Context manager interface to `teeplot.tee`.

Plot save is dispatched upon exiting the context. Return value is the
Expand All @@ -377,3 +378,34 @@ def teed(*args: list, **kwargs: dict):
yield handle
finally:
saveit()


def teewrap(
**teeplot_kwargs: object,
):
"""Decorator interface to `teeplot.tee`

Works by returning a decorator that wraps `f` by calling `teeplot.tee` using
`f` and any passed in arguments and keyword arguments. However, using
`teeplot_outattrs` like in `teeplot.tee` will cause printed attributes to be
the same across function calls. For printing attributes on a per-call basis,
see `teeplot_outinclude` in `teeplot.tee`.
"""
if not all(k.startswith("teeplot_") for k in teeplot_kwargs):
raise ValueError(
"The `teewrap` decorator only accepts teeplot_* keyword arguments"
)

def decorator(f: typing.Callable):
@functools.wraps(f)
def inner(*args, **kwargs):
return tee(
f,
*args,
**teeplot_kwargs,
**kwargs,
)

return inner

return decorator
96 changes: 96 additions & 0 deletions tests/test_teewrap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#!/usr/bin/env python

'''
`tee` tests for `teeplot` package.
'''

import functools
import os

import numpy as np
import pytest
import seaborn as sns

from teeplot import teeplot as tp


@tp.teewrap(
teeplot_outattrs={
'additional' : 'teedmetadata',
'for' : 'output-filename',
'_one-for' : 'exclusion',
},
)
@functools.wraps(sns.lineplot)
def teed_snslineplot_outattrs(*args, **kwargs):
return sns.lineplot(*args, **kwargs)

def test():

teed_snslineplot_outattrs(
x='timepoint',
y='signal',
hue='region',
style='event',
data=sns.load_dataset('fmri'),
)

for ext in '.pdf', '.png':
assert os.path.exists(
os.path.join('teeplots', f'additional=teedmetadata+for=output-filename+hue=region+style=event+viz=lineplot+x=timepoint+y=signal+ext={ext}'),
)


@pytest.mark.parametrize("format", [".png", ".pdf", ".ps", ".eps", ".svg"])
def test_outformat(format):

# adapted from https://seaborn.pydata.org/generated/seaborn.lineplot.html
np.random.seed(1)
x, y = np.random.normal(size=(2, 5000)).cumsum(axis=1)

@tp.teewrap(
teeplot_outattrs={
'outformat' : 'teedmetadata',
},
teeplot_subdir='mydirectory',
teeplot_save={format},
)
@functools.wraps(sns.lineplot)
def teed_lineplot_outformat(*args, **kwargs):
return sns.lineplot(*args, **kwargs)

teed_lineplot_outformat(
x=x,
y=y,
sort=False,
lw=1,
)

assert os.path.exists(
os.path.join('teeplots', 'mydirectory', f'outformat=teedmetadata+viz=lineplot+ext={format}'),
)


@tp.teewrap(teeplot_outinclude=['a', 'b'])
@functools.wraps(sns.lineplot)
def teed_snslineplot_extra_args(*args, a, b, **kwargs):
return sns.lineplot(*args, **kwargs)


@pytest.mark.parametrize('a', [False, 1, 1])
@pytest.mark.parametrize('b', ['asdf', ''])
def test_included_outattrs(a, b):

teed_snslineplot_extra_args(
a=a,
b=b,
x='timepoint',
y='signal',
hue='region',
data=sns.load_dataset('fmri'),
)

for ext in '.pdf', '.png':
assert os.path.exists(
os.path.join('teeplots', f'a={a}+b={b}+hue=region+viz=lineplot+x=timepoint+y=signal+ext={ext}'),
)
Loading