Added database tests along with workflows
This commit is contained in:
parent
b0e82fdefc
commit
3dcad9badf
6 changed files with 747 additions and 380 deletions
27
.github/workflows/black.yaml
vendored
Normal file
27
.github/workflows/black.yaml
vendored
Normal file
|
@ -0,0 +1,27 @@
|
|||
name: Black
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master, stable ]
|
||||
pull_request:
|
||||
branches: [ master, stable ]
|
||||
|
||||
jobs:
|
||||
black:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: 3.8
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install black
|
||||
|
||||
- name: Black check
|
||||
run: |
|
||||
black --check butterrobot
|
32
.github/workflows/pytest.yaml
vendored
Normal file
32
.github/workflows/pytest.yaml
vendored
Normal file
|
@ -0,0 +1,32 @@
|
|||
name: Pytest
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master, stable ]
|
||||
pull_request:
|
||||
branches: [ master, stable ]
|
||||
|
||||
jobs:
|
||||
pytest:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [3.8]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --upgrade pip poetry
|
||||
poetry install
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
ls
|
||||
poetry run pytest --cov=butterrobot
|
|
@ -1,5 +1,6 @@
|
|||
import hashlib
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
import dataset
|
||||
|
||||
|
@ -19,14 +20,40 @@ class Query:
|
|||
|
||||
@classmethod
|
||||
def all(cls):
|
||||
for row in cls._table.all():
|
||||
yield cls._obj(**row)
|
||||
"""
|
||||
Iterate over all rows on a table.
|
||||
"""
|
||||
for row in db[cls.tablename].all():
|
||||
yield cls.obj(**row)
|
||||
|
||||
@classmethod
|
||||
def exists(cls, *args, **kwargs):
|
||||
def get(cls, **kwargs) -> 'class':
|
||||
"""
|
||||
Returns the object representation of an specific row in a table.
|
||||
Allows retrieving object by multiple columns.
|
||||
Raises `NotFound` error if query return no results.
|
||||
"""
|
||||
row = db[cls.tablename].find_one()
|
||||
if not row:
|
||||
raise cls.NotFound
|
||||
return cls.obj(**row)
|
||||
|
||||
@classmethod
|
||||
def create(cls, **kwargs):
|
||||
"""
|
||||
Creates a new row in the table with the provided arguments.
|
||||
Returns the row_id
|
||||
TODO: Return obj?
|
||||
"""
|
||||
return db[cls.tablename].insert(kwargs)
|
||||
|
||||
@classmethod
|
||||
def exists(cls, **kwargs) -> bool:
|
||||
"""
|
||||
Check for the existence of a row with the provided columns.
|
||||
"""
|
||||
try:
|
||||
# Using only *args since those are supposed to be mandatory
|
||||
cls.get(*args)
|
||||
cls.get(**kwargs)
|
||||
except cls.NotFound:
|
||||
return False
|
||||
return True
|
||||
|
@ -34,27 +61,16 @@ class Query:
|
|||
@classmethod
|
||||
def update(cls, row_id, **fields):
|
||||
fields.update({"id": row_id})
|
||||
return cls._table.update(fields, ("id", ))
|
||||
return db[cls.tablename].update(fields, ("id", ))
|
||||
|
||||
@classmethod
|
||||
def get(cls, _id):
|
||||
row = cls._table.find_one(id=_id)
|
||||
if not row:
|
||||
raise cls.NotFound
|
||||
return cls._obj(**row)
|
||||
def delete(cls, id):
|
||||
return db[cls.tablename].delete(id=id)
|
||||
|
||||
@classmethod
|
||||
def update(cls, _id, **fields):
|
||||
fields.update({"id": _id})
|
||||
return cls._table.update(fields, ("id"))
|
||||
|
||||
@classmethod
|
||||
def delete(cls, _id):
|
||||
cls._table.delete(id=_id)
|
||||
|
||||
class UserQuery(Query):
|
||||
_table = db["users"]
|
||||
_obj = User
|
||||
tablename = "users"
|
||||
obj = User
|
||||
|
||||
@classmethod
|
||||
def _hash_password(cls, password):
|
||||
|
@ -63,32 +79,23 @@ class UserQuery(Query):
|
|||
).hex()
|
||||
|
||||
@classmethod
|
||||
def check_credentials(cls, username, password):
|
||||
user = cls._table.find_one(username=username)
|
||||
def check_credentials(cls, username, password) -> Union[User, 'False']:
|
||||
user = db[cls.tablename].find_one(username=username)
|
||||
if user:
|
||||
hash_password = cls._hash_password(password)
|
||||
if user["password"] == hash_password:
|
||||
return cls._obj(**user)
|
||||
return cls.obj(**user)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def create(cls, username, password):
|
||||
hash_password = cls._hash_password(password)
|
||||
cls._table.insert({"username": username, "password": hash_password})
|
||||
|
||||
@classmethod
|
||||
def delete(cls, username):
|
||||
return cls._table.delete(username=username)
|
||||
|
||||
@classmethod
|
||||
def update(cls, username, **fields):
|
||||
fields.update({"username": username})
|
||||
return cls._table.update(fields, ("username",))
|
||||
def create(cls, **kwargs):
|
||||
kwargs["password"] = cls._hash_password(kwargs["password"])
|
||||
super().create(**kwargs)
|
||||
|
||||
|
||||
class ChannelQuery(Query):
|
||||
_table = db["channels"]
|
||||
_obj = Channel
|
||||
tablename = "channels"
|
||||
obj = Channel
|
||||
|
||||
@classmethod
|
||||
def create(cls, platform, platform_channel_id, enabled=False, channel_raw={}):
|
||||
|
@ -98,8 +105,8 @@ class ChannelQuery(Query):
|
|||
"enabled": enabled,
|
||||
"channel_raw": channel_raw,
|
||||
}
|
||||
cls._table.insert(params)
|
||||
return cls._obj(**params)
|
||||
super().create(**params)
|
||||
return cls.obj(**params)
|
||||
|
||||
@classmethod
|
||||
def get(cls, _id):
|
||||
|
@ -110,7 +117,7 @@ class ChannelQuery(Query):
|
|||
|
||||
@classmethod
|
||||
def get_by_platform(cls, platform, platform_channel_id):
|
||||
result = cls._table.find_one(
|
||||
result = cls.tablename.find_one(
|
||||
platform=platform, platform_channel_id=platform_channel_id
|
||||
)
|
||||
if not result:
|
||||
|
@ -118,7 +125,7 @@ class ChannelQuery(Query):
|
|||
|
||||
plugins = ChannelPluginQuery.get_from_channel_id(result["id"])
|
||||
|
||||
return cls._obj(plugins={plugin.plugin_id: plugin for plugin in plugins}, **result)
|
||||
return cls.obj(plugins={plugin.plugin_id: plugin for plugin in plugins}, **result)
|
||||
|
||||
@classmethod
|
||||
def delete(cls, _id):
|
||||
|
@ -127,12 +134,12 @@ class ChannelQuery(Query):
|
|||
|
||||
|
||||
class ChannelPluginQuery(Query):
|
||||
_table = db["channel_plugin"]
|
||||
_obj = ChannelPlugin
|
||||
tablename = "channel_plugin"
|
||||
obj = ChannelPlugin
|
||||
|
||||
@classmethod
|
||||
def create(cls, channel_id, plugin_id, enabled=False, config={}):
|
||||
if cls.exists(channel_id, plugin_id):
|
||||
if cls.exists(id=channel_id, plugin_id=plugin_id):
|
||||
raise cls.Duplicated
|
||||
|
||||
params = {
|
||||
|
@ -141,25 +148,13 @@ class ChannelPluginQuery(Query):
|
|||
"enabled": enabled,
|
||||
"config": config,
|
||||
}
|
||||
obj_id = cls._table.insert(params)
|
||||
return cls._obj(id=obj_id, **params)
|
||||
|
||||
@classmethod
|
||||
def get(cls, channel_id, plugin_id):
|
||||
result = cls._table.find_one(channel_id=channel_id, plugin_id=plugin_id)
|
||||
if not result:
|
||||
raise cls.NotFound
|
||||
return cls._obj(**result)
|
||||
obj_id = super().create(**params)
|
||||
return cls.obj(id=obj_id, **params)
|
||||
|
||||
@classmethod
|
||||
def get_from_channel_id(cls, channel_id):
|
||||
yield from [cls._obj(**row) for row in cls._table.find(channel_id=channel_id)]
|
||||
|
||||
@classmethod
|
||||
def delete(cls, channel_plugin_id):
|
||||
return cls._table.delete(id=channel_plugin_id)
|
||||
yield from [cls.obj(**row) for row in cls.tablename.find(channel_id=channel_id)]
|
||||
|
||||
@classmethod
|
||||
def delete_by_channel(cls, channel_id):
|
||||
cls._table.delete(channel_id=channel_id)
|
||||
|
||||
cls.tablename.delete(channel_id=channel_id)
|
||||
|
|
848
poetry.lock
generated
848
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -27,6 +27,8 @@ flake8 = "^3.7.9"
|
|||
rope = "^0.16.0"
|
||||
isort = "^4.3.21"
|
||||
ipdb = "^0.13.2"
|
||||
pytest = "^6.1.2"
|
||||
pytest-cov = "^2.10.1"
|
||||
|
||||
[tool.poetry.plugins]
|
||||
[tool.poetry.plugins."butterrobot.plugins"]
|
||||
|
|
109
tests/test_db.py
Normal file
109
tests/test_db.py
Normal file
|
@ -0,0 +1,109 @@
|
|||
import os.path
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from unittest import mock
|
||||
|
||||
import dataset
|
||||
import pytest
|
||||
|
||||
from butterrobot import db
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummyItem:
|
||||
id: int
|
||||
foo: str
|
||||
|
||||
|
||||
class DummyQuery(db.Query):
|
||||
tablename = "dummy"
|
||||
obj = DummyItem
|
||||
|
||||
|
||||
class MockDatabase:
|
||||
def __init__(self):
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
|
||||
def __enter__(self):
|
||||
db_path = os.path.join(self.temp_dir.name, "db.sqlite")
|
||||
db.db = dataset.connect(f"sqlite:///{db_path}")
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.temp_dir.cleanup()
|
||||
|
||||
|
||||
def test_query_create_ok():
|
||||
with MockDatabase():
|
||||
assert DummyQuery.create(foo="bar")
|
||||
|
||||
|
||||
def test_query_delete_ok():
|
||||
with MockDatabase():
|
||||
item_id = DummyQuery.create(foo="bar")
|
||||
assert DummyQuery.delete(item_id)
|
||||
|
||||
|
||||
def test_query_exists_by_id_ok():
|
||||
with MockDatabase():
|
||||
assert not DummyQuery.exists(id=1)
|
||||
item_id = DummyQuery.create(foo="bar")
|
||||
assert DummyQuery.exists(id=item_id)
|
||||
|
||||
|
||||
def test_query_exists_by_attribute_ok():
|
||||
with MockDatabase():
|
||||
assert not DummyQuery.exists(id=1)
|
||||
item_id = DummyQuery.create(foo="bar")
|
||||
assert DummyQuery.exists(foo="bar")
|
||||
|
||||
|
||||
def test_query_get_ok():
|
||||
with MockDatabase():
|
||||
item_id = DummyQuery.create(foo="bar")
|
||||
item = DummyQuery.get(id=item_id)
|
||||
assert item.id
|
||||
|
||||
|
||||
def test_query_all_ok():
|
||||
with MockDatabase():
|
||||
assert len(list(DummyQuery.all())) == 0
|
||||
[DummyQuery.create(foo="bar") for i in range(0, 3)]
|
||||
assert len(list(DummyQuery.all())) == 3
|
||||
|
||||
|
||||
def test_update_ok():
|
||||
with MockDatabase():
|
||||
expected = "bar2"
|
||||
item_id = DummyQuery.create(foo="bar")
|
||||
assert DummyQuery.update(item_id, foo=expected)
|
||||
item = DummyQuery.get(id=item_id)
|
||||
assert item.foo == expected
|
||||
|
||||
|
||||
def test_create_user_sets_password_ok():
|
||||
password = "password"
|
||||
with MockDatabase():
|
||||
user_id = db.UserQuery.create(username="foo", password=password)
|
||||
user = db.UserQuery.get(id=user_id)
|
||||
assert user.password == db.UserQuery._hash_password(password)
|
||||
|
||||
|
||||
def test_user_check_credentials_ok():
|
||||
with MockDatabase():
|
||||
username = "foo"
|
||||
password = "bar"
|
||||
user_id = db.UserQuery.create(username=username, password=password)
|
||||
user = db.UserQuery.get(id=user_id)
|
||||
user = db.UserQuery.check_credentials(username, password)
|
||||
assert isinstance(user, db.UserQuery.obj)
|
||||
|
||||
|
||||
def test_user_check_credentials_ko():
|
||||
with MockDatabase():
|
||||
username = "foo"
|
||||
password = "bar"
|
||||
user_id = db.UserQuery.create(username=username, password=password)
|
||||
user = db.UserQuery.get(id=user_id)
|
||||
assert not db.UserQuery.check_credentials(username, "error")
|
||||
assert not db.UserQuery.check_credentials("error", password)
|
||||
assert not db.UserQuery.check_credentials("error", "error")
|
Loading…
Add table
Add a link
Reference in a new issue