Skip to content

Commit 6702bc1

Browse files
committed
Fix flake8 warnings
1 parent 3c6597b commit 6702bc1

File tree

4 files changed

+31
-18
lines changed

4 files changed

+31
-18
lines changed

pyindicators/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .indicators import sma, rsi, is_crossover, crossunder, ema, wilders_rsi, \
1+
from .indicators import sma, rsi, crossunder, ema, wilders_rsi, \
22
crossover, is_crossover, wma, macd, willr
33

44
__all__ = [

pyindicators/indicators/crossover.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@ def crossover(
99
data: Union[PdDataFrame, PlDataFrame],
1010
first_column: str,
1111
second_column: str,
12-
result_column = "crossover",
12+
result_column="crossover",
1313
data_points: int = None,
1414
strict: bool = True,
1515
) -> Union[PdDataFrame, PlDataFrame]:
1616
"""
17-
Identifies crossover points where `first_column` crosses above or below `second_column`.
17+
Identifies crossover points where `first_column` crosses above
18+
or below `second_column`.
1819
1920
Args:
2021
data: Pandas or Polars DataFrame
@@ -32,15 +33,18 @@ def crossover(
3233

3334
# Restrict data to the last `data_points` rows if specified
3435
if data_points is not None:
35-
data = data.tail(data_points) if isinstance(data, PdDataFrame) else data.slice(-data_points)
36+
data = data.tail(data_points) if isinstance(data, PdDataFrame) \
37+
else data.slice(-data_points)
3638

3739
# Pandas Implementation
3840
if isinstance(data, PdDataFrame):
3941
col1, col2 = data[first_column], data[second_column]
4042
prev_col1, prev_col2 = col1.shift(1), col2.shift(1)
4143

4244
if strict:
43-
crossover_mask = ((prev_col1 < prev_col2) & (col1 > col2)) | ((prev_col1 > prev_col2) & (col1 < col2))
45+
crossover_mask = (
46+
(prev_col1 < prev_col2)
47+
& (col1 > col2)) | ((prev_col1 > prev_col2) & (col1 < col2))
4448
else:
4549
crossover_mask = (col1 > col2) | (col1 < col2)
4650

@@ -52,12 +56,14 @@ def crossover(
5256
prev_col1, prev_col2 = col1.shift(1), col2.shift(1)
5357

5458
if strict:
55-
crossover_mask = ((prev_col1 < prev_col2) & (col1 > col2)) | ((prev_col1 > prev_col2) & (col1 < col2))
59+
crossover_mask = ((prev_col1 < prev_col2) & (col1 > col2)) | \
60+
((prev_col1 > prev_col2) & (col1 < col2))
5661
else:
5762
crossover_mask = (col1 > col2) | (col1 < col2)
5863

5964
# Convert boolean mask to 1s and 0s
60-
data = data.with_columns(pl.when(crossover_mask).then(1).otherwise(0).alias(result_column))
65+
data = data.with_columns(pl.when(crossover_mask).then(1)
66+
.otherwise(0).alias(result_column))
6167

6268
return data
6369

pyindicators/indicators/macd.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from typing import Union
22

3-
import numpy as np
43
from pandas import DataFrame as PdDataFrame
54
from polars import DataFrame as PlDataFrame
6-
import pandas as pd
75
import polars as pl
86

97
from pyindicators.exceptions import PyIndicatorException
@@ -21,20 +19,27 @@ def macd(
2119
histogram_column: str = "macd_histogram"
2220
) -> Union[PdDataFrame, PlDataFrame]:
2321
"""
24-
Calculate the MACD (Moving Average Convergence Divergence) for a given DataFrame.
22+
Calculate the MACD (Moving Average Convergence Divergence) for
23+
a given DataFrame.
2524
2625
Args:
27-
data (Union[pd.DataFrame, pl.DataFrame]): Input data containing the price series.
26+
data (Union[pd.DataFrame, pl.DataFrame]): Input data containing
27+
the price series.
2828
source_column (str): Column name for the price series.
29-
short_period (int, optional): Period for the short-term EMA (default: 12).
30-
long_period (int, optional): Period for the long-term EMA (default: 26).
31-
signal_period (int, optional): Period for the Signal Line EMA (default: 9).
29+
short_period (int, optional): Period for the short-term EMA
30+
(default: 12).
31+
long_period (int, optional): Period for the long-term EMA
32+
(default: 26).
33+
signal_period (int, optional): Period for the Signal Line
34+
EMA (default: 9).
3235
macd_column (str, optional): Column name to store the MACD line.
3336
signal_column (str, optional): Column name to store the Signal line.
34-
histogram_column (str, optional): Column name to store the MACD histogram.
37+
histogram_column (str, optional): Column name to store the
38+
MACD histogram.
3539
3640
Returns:
37-
Union[pd.DataFrame, pl.DataFrame]: DataFrame with MACD, Signal Line, and Histogram.
41+
Union[pd.DataFrame, pl.DataFrame]: DataFrame with MACD, Signal
42+
Line, and Histogram.
3843
"""
3944
if source_column not in data.columns:
4045
raise PyIndicatorException(

pyindicators/indicators/williams_percent_range.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,10 @@ def willr(
4949
return data.drop(columns=["high_n", "low_n"])
5050

5151
elif isinstance(data, pl.DataFrame):
52-
high_n = data.select(pl.col(high_column).rolling_max(period).alias("high_n"))
53-
low_n = data.select(pl.col(low_column).rolling_min(period).alias("low_n"))
52+
high_n = data.select(pl.col(high_column).rolling_max(period)
53+
.alias("high_n"))
54+
low_n = data.select(pl.col(low_column).rolling_min(period)
55+
.alias("low_n"))
5456

5557
data = data.with_columns([
5658
high_n["high_n"],

0 commit comments

Comments
 (0)