Support pandas' iloc indexer (#191)

This commit is contained in:
Aymeric Roucher 2025-01-14 19:27:07 +01:00 committed by GitHub
parent 77f656c80d
commit ce1cd6d906
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 0 deletions

View File

@ -685,6 +685,9 @@ def evaluate_subscript(
if isinstance(value, pd.core.indexing._LocIndexer): if isinstance(value, pd.core.indexing._LocIndexer):
parent_object = value.obj parent_object = value.obj
return parent_object.loc[index] return parent_object.loc[index]
if isinstance(value, pd.core.indexing._iLocIndexer):
parent_object = value.obj
return parent_object.iloc[index]
if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)): if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)):
return value[index] return value[index]
elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy): elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy):

View File

@ -808,6 +808,7 @@ filtered_df = df.loc[df['AtomicNumber'].isin([104])]
) )
assert np.array_equal(result.values[0], [104, 1]) assert np.array_equal(result.values[0], [104, 1])
# Test groupby
code = """import pandas as pd code = """import pandas as pd
data = pd.DataFrame.from_dict([ data = pd.DataFrame.from_dict([
{"Pclass": 1, "Survived": 1}, {"Pclass": 1, "Survived": 1},
@ -821,6 +822,21 @@ survival_rate_by_class = data.groupby('Pclass')['Survived'].mean()
) )
assert result.values[1] == 0.5 assert result.values[1] == 0.5
# Test loc and iloc
code = """import pandas as pd
data = pd.DataFrame.from_dict([
{"Pclass": 1, "Survived": 1},
{"Pclass": 2, "Survived": 0},
{"Pclass": 2, "Survived": 1}
])
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
survival_rate_sorted = data.sort_values(by='Survived', ascending=False).iloc[0]
"""
result, _ = evaluate_python_code(
code, {}, state={}, authorized_imports=["pandas"]
)
def test_starred(self): def test_starred(self):
code = """ code = """
from math import radians, sin, cos, sqrt, atan2 from math import radians, sin, cos, sqrt, atan2