Browse Source

group
@debug_view

ignatz 4 năm trước cách đây
mục cha
commit
2e43995aa9

+ 1 - 1
.idea/misc.xml

@@ -1,4 +1,4 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <project version="4">
-  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (st-cloud)" project-jdk-type="Python SDK" />
 </project>

+ 4 - 2
.idea/st_cloud.iml

@@ -13,8 +13,10 @@
     </facet>
   </component>
   <component name="NewModuleRootManager">
-    <content url="file://$MODULE_DIR$" />
-    <orderEntry type="inheritedJdk" />
+    <content url="file://$MODULE_DIR$">
+      <excludeFolder url="file://$MODULE_DIR$/venv" />
+    </content>
+    <orderEntry type="jdk" jdkName="Python 3.9 (st-cloud)" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
   <component name="TemplatesService">

+ 1 - 0
3

@@ -0,0 +1 @@
+fadfasdf

+ 1 - 0
7

@@ -0,0 +1 @@
+fadfasdf

+ 0 - 15
_user/admin.py

@@ -1,15 +0,0 @@
-from django.contrib import admin
-from .models import User, LoginToken
-
-
-# Register your models here.
-class UserAdmin(admin.ModelAdmin):
-    list_display = ["username", "password", "email"]
-
-
-class LoginTokenAdmin(admin.ModelAdmin):
-    list_display = ["_user", "token"]
-
-
-admin.site.register(User, UserAdmin)
-admin.site.register(LoginToken, LoginTokenAdmin)

+ 0 - 5
_user/apps.py

@@ -1,5 +0,0 @@
-from django.apps import AppConfig
-
-
-class UserConfig(AppConfig):
-    name = 'user'

+ 0 - 45
_user/decorators.py

@@ -1,45 +0,0 @@
-from functools import wraps
-from urllib.parse import urlparse
-
-from django.conf import settings
-from django.shortcuts import resolve_url
-from django.http import JsonResponse
-
-from .models import User, LoginToken
-
-
-def user_passes_test(test_func, error):
-    def decorator(view_func):
-        @wraps(view_func)
-        def _wrapped_view(request, *args, **kwargs):
-            if test_func(request):
-                return view_func(request, *args, **kwargs)
-            return JsonResponse({'code': 401, 'error': error}, status=401)
-
-        return _wrapped_view
-
-    return decorator
-
-
-def login_required(function=None, error='error'):
-    """
-    Decorator for views that checks that the _user is logged in, redirecting
-    to the log-in page if necessary.
-    """
-
-    def is_login(request):
-        username = request.data.get('username', '')
-        token = request.data.get('token', '')
-        try:
-            user = User.objects.get(username='username')
-            if user.check_token(token):
-                user.tokens.get(token=token)
-                return True
-        except:
-            return False
-        return False
-
-    actual_decorator = user_passes_test(is_login, '请登录')
-    if function:
-        return actual_decorator(function)
-    return actual_decorator

+ 0 - 86
_user/models.py

@@ -1,86 +0,0 @@
-import unicodedata
-
-from datetime import datetime, time
-from django.core.mail import send_mail
-from django.db import models
-from django.utils.translation import gettext_lazy as _
-from django.utils.http import base36_to_int, int_to_base36
-from django.conf import settings
-from django.utils.crypto import constant_time_compare, salted_hmac
-from .validators import ASCIIUsernameValidator
-
-
-class User(models.Model):
-    username = models.CharField(
-        _('username'),
-        max_length=25,
-        unique=True,
-        help_text=_('Required. 25 characters or fewer. Letters, digits and _ only.'),
-        validators=[ASCIIUsernameValidator()],
-        error_messages={
-            'unique': _("A _user with that username already exists."),
-        },
-    )
-    password = models.CharField(_('password'), max_length=128)
-    last_login = models.DateTimeField(_('last login'), blank=True, null=True)
-    email = models.EmailField(_('email address'), unique=True)
-
-    class Meta:
-        db_table = '_user'
-        verbose_name = verbose_name_plural = '用户信息表'
-
-    def set_password(self, password):
-        # TODO: 密码强度检验,密码hash存储
-        self.password = password
-
-    def send_email(self, subject, message, from_email=None, **kwargs):
-        send_mail(subject, message, from_email, [self.email], **kwargs)
-
-    def make_token(self):
-        return self._make_token(_timestamp())
-
-    def check_token(self, token):
-        if not token:
-            return False
-        try:
-            ts_b36, hash_str = token.split('-')
-        except ValueError:
-            return False
-
-        try:
-            ts = base36_to_int(ts_b36)
-        except ValueError:
-            return False
-
-        if self._make_token(ts) != token:
-            return False
-
-        timestamp = _timestamp()
-        if (timestamp - ts) > settings.PASSWORD_RESET_TIMEOUT:
-            return False
-
-        return True
-
-    def _make_token(self, timestamp):
-        ts_b36 = int_to_base36(timestamp)
-        salt = settings.SALT
-        value = self._make_hash_value(timestamp)
-        secret = settings.SECRET_KEY
-        hash_str = salted_hmac(
-            salt, value, secret=secret, algorithm='sha256'
-        ).hexdigest()[::2]
-        token = "%s-%s" % (ts_b36, hash_str)
-        return token
-
-    def _make_hash_value(self, timestamp):
-        return f'{self.pk}{self.password}{timestamp}{self.email}'
-
-
-def _timestamp():
-    dt = datetime.now()
-    return int((dt - datetime(2001, 1, 1)).total_seconds())
-
-
-class LoginToken(models.Model):
-    user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='tokens')
-    token = models.CharField(max_length=256)

+ 0 - 203
_user/password_validation.py

@@ -1,203 +0,0 @@
-import functools
-import gzip
-import re
-from difflib import SequenceMatcher
-from pathlib import Path
-
-from django.conf import settings
-from django.core.exceptions import (
-    FieldDoesNotExist, ImproperlyConfigured, ValidationError,
-)
-from django.utils.functional import lazy
-from django.utils.html import format_html, format_html_join
-from django.utils.module_loading import import_string
-from django.utils.translation import gettext as _, ngettext
-
-
-@functools.lru_cache(maxsize=None)
-def get_default_password_validators():
-    return get_password_validators(settings.AUTH_PASSWORD_VALIDATORS)
-
-
-def get_password_validators(validator_config):
-    validators = []
-    for validator in validator_config:
-        try:
-            klass = import_string(validator['NAME'])
-        except ImportError:
-            msg = "The module in NAME could not be imported: %s. Check your AUTH_PASSWORD_VALIDATORS setting."
-            raise ImproperlyConfigured(msg % validator['NAME'])
-        validators.append(klass(**validator.get('OPTIONS', {})))
-
-    return validators
-
-
-def validate_password(password, user=None, password_validators=None):
-    """
-    Validate whether the password meets all validator requirements.
-
-    If the password is valid, return ``None``.
-    If the password is invalid, raise ValidationError with all error messages.
-    """
-    errors = []
-    if password_validators is None:
-        password_validators = get_default_password_validators()
-    for validator in password_validators:
-        try:
-            validator.validate(password, user)
-        except ValidationError as error:
-            errors.append(error)
-    if errors:
-        raise ValidationError(errors)
-
-
-def password_changed(password, user=None, password_validators=None):
-    """
-    Inform all validators that have implemented a password_changed() method
-    that the password has been changed.
-    """
-    if password_validators is None:
-        password_validators = get_default_password_validators()
-    for validator in password_validators:
-        password_changed = getattr(validator, 'password_changed', lambda *a: None)
-        password_changed(password, user)
-
-
-def password_validators_help_texts(password_validators=None):
-    """
-    Return a list of all help texts of all configured validators.
-    """
-    help_texts = []
-    if password_validators is None:
-        password_validators = get_default_password_validators()
-    for validator in password_validators:
-        help_texts.append(validator.get_help_text())
-    return help_texts
-
-
-def _password_validators_help_text_html(password_validators=None):
-    """
-    Return an HTML string with all help texts of all configured validators
-    in an <ul>.
-    """
-    help_texts = password_validators_help_texts(password_validators)
-    help_items = format_html_join('', '<li>{}</li>', ((help_text,) for help_text in help_texts))
-    return format_html('<ul>{}</ul>', help_items) if help_items else ''
-
-
-password_validators_help_text_html = lazy(_password_validators_help_text_html, str)
-
-
-class MinimumLengthValidator:
-    """
-    Validate whether the password is of a minimum length.
-    """
-    def __init__(self, min_length=8):
-        self.min_length = min_length
-
-    def validate(self, password, user=None):
-        if len(password) < self.min_length:
-            raise ValidationError(
-                ngettext(
-                    "This password is too short. It must contain at least %(min_length)d character.",
-                    "This password is too short. It must contain at least %(min_length)d characters.",
-                    self.min_length
-                ),
-                code='password_too_short',
-                params={'min_length': self.min_length},
-            )
-
-    def get_help_text(self):
-        return ngettext(
-            "Your password must contain at least %(min_length)d character.",
-            "Your password must contain at least %(min_length)d characters.",
-            self.min_length
-        ) % {'min_length': self.min_length}
-
-
-class UserAttributeSimilarityValidator:
-    """
-    Validate whether the password is sufficiently different from the _user's
-    attributes.
-
-    If no specific attributes are provided, look at a sensible list of
-    defaults. Attributes that don't exist are ignored. Comparison is made to
-    not only the full attribute value, but also its components, so that, for
-    example, a password is validated against either part of an email address,
-    as well as the full address.
-    """
-    DEFAULT_USER_ATTRIBUTES = ('username', 'first_name', 'last_name', 'email')
-
-    def __init__(self, user_attributes=DEFAULT_USER_ATTRIBUTES, max_similarity=0.7):
-        self.user_attributes = user_attributes
-        self.max_similarity = max_similarity
-
-    def validate(self, password, user=None):
-        if not user:
-            return
-
-        for attribute_name in self.user_attributes:
-            value = getattr(user, attribute_name, None)
-            if not value or not isinstance(value, str):
-                continue
-            value_parts = re.split(r'\W+', value) + [value]
-            for value_part in value_parts:
-                if SequenceMatcher(a=password.lower(), b=value_part.lower()).quick_ratio() >= self.max_similarity:
-                    try:
-                        verbose_name = str(user._meta.get_field(attribute_name).verbose_name)
-                    except FieldDoesNotExist:
-                        verbose_name = attribute_name
-                    raise ValidationError(
-                        _("The password is too similar to the %(verbose_name)s."),
-                        code='password_too_similar',
-                        params={'verbose_name': verbose_name},
-                    )
-
-    def get_help_text(self):
-        return _('Your password can’t be too similar to your other personal information.')
-
-
-class CommonPasswordValidator:
-    """
-    Validate whether the password is a common password.
-
-    The password is rejected if it occurs in a provided list of passwords,
-    which may be gzipped. The list Django ships with contains 20000 common
-    passwords (lowercased and deduplicated), created by Royce Williams:
-    https://gist.github.com/roycewilliams/281ce539915a947a23db17137d91aeb7
-    The password list must be lowercased to match the comparison in validate().
-    """
-    DEFAULT_PASSWORD_LIST_PATH = Path(__file__).resolve().parent / 'common-passwords.txt.gz'
-
-    def __init__(self, password_list_path=DEFAULT_PASSWORD_LIST_PATH):
-        try:
-            with gzip.open(password_list_path, 'rt', encoding='utf-8') as f:
-                self.passwords = {x.strip() for x in f}
-        except OSError:
-            with open(password_list_path) as f:
-                self.passwords = {x.strip() for x in f}
-
-    def validate(self, password, user=None):
-        if password.lower().strip() in self.passwords:
-            raise ValidationError(
-                _("This password is too common."),
-                code='password_too_common',
-            )
-
-    def get_help_text(self):
-        return _('Your password can’t be a commonly used password.')
-
-
-class NumericPasswordValidator:
-    """
-    Validate whether the password is alphanumeric.
-    """
-    def validate(self, password, user=None):
-        if password.isdigit():
-            raise ValidationError(
-                _("This password is entirely numeric."),
-                code='password_entirely_numeric',
-            )
-
-    def get_help_text(self):
-        return _('Your password can’t be entirely numeric.')

+ 0 - 8
_user/serializers.py

@@ -1,8 +0,0 @@
-from rest_framework import serializers
-from .models import User, LoginToken
-
-
-class UserSerializer(serializers.ModelSerializer):
-    class Meta:
-        model = User
-        fields = ('username', 'password', 'email')

+ 0 - 14
_user/urls.py

@@ -1,14 +0,0 @@
-# The views used below are normally mapped in the AdminSite instance.
-# This URLs file is used to provide a reliable view deployment for test purposes.
-# It is also provided as a convenience to those who want to deploy these URLs
-# elsewhere.
-
-from django.conf.urls import url
-from django.urls import path
-from . import views
-from rest_framework.urlpatterns import format_suffix_patterns
-
-urlpatterns = [
-]
-
-urlpatterns = format_suffix_patterns(urlpatterns)

+ 0 - 25
_user/validators.py

@@ -1,25 +0,0 @@
-import re
-
-from django.core import validators
-from django.utils.deconstruct import deconstructible
-from django.utils.translation import gettext_lazy as _
-
-
-@deconstructible
-class ASCIIUsernameValidator(validators.RegexValidator):
-    regex = r'^[\w]+\Z'
-    message = _(
-        'Enter a valid username. This value may contain only English letters, '
-        'numbers, and _ characters.'
-    )
-    flags = re.ASCII
-
-
-@deconstructible
-class UnicodeUsernameValidator(validators.RegexValidator):
-    regex = r'^[\w]+\Z'
-    message = _(
-        'Enter a valid username. This value may contain only letters, '
-        'numbers, and _ characters.'
-    )
-    flags = 0

+ 0 - 123
_user/views.py

@@ -1,123 +0,0 @@
-from django.shortcuts import render
-
-# Create your views here.
-
-from datetime import datetime, time
-from .models import User, LoginToken
-from django.http import JsonResponse, HttpResponse
-from rest_framework.decorators import api_view
-from django.middleware.csrf import rotate_token
-
-
-def auth_with_username_or_email(username, password):
-    if '@' in username:
-        user = User.objects.get(email=username, password=password)
-    else:
-        user = User.objects.get(username=username, password=password)
-    return user
-
-
-@api_view(['POST'])
-def register(request):
-    username = request.data.get('username', '')
-    password = request.data.get('password', '')
-    email = request.data.get('email', '')
-
-    try:
-        User.objects.create(username=username, password=password, email=email)
-        print('注册成功')
-        return JsonResponse({'code': 200})
-    except Exception as e:
-        print(e)
-        return JsonResponse({'code': 303, 'error': str(e)}, status=303)
-
-
-@api_view(['POST'])
-def login(request):
-    username = request.data.get('username', '')
-    password = request.data.get('password', '')
-    token = request.data.get('token', '')
-
-    try:
-        user = auth_with_username_or_email(username, password)
-        print(user)
-    except Exception as e:
-        print(e)
-        print('用户名或密码错误')
-        return JsonResponse({'code': 303, 'error': '用户名或密码错误'}, status=303)
-
-    print(f'token = {token}')
-    if user.check_token(token):
-        try:
-            user_token = user.tokens.get(token=token)
-            print('已登录')
-            user_token.delete()
-            # return JsonResponse({'code': 303, 'msg': '已登录'}, status=303)
-        except Exception as e:
-            print('token无效')
-    else:
-        print('token已过期')
-
-    user.last_login = datetime.now()
-
-    new_token = user.make_token()
-    user_token = LoginToken()
-    user_token.user = user
-    user_token.token = new_token
-    user_token.save()
-
-    if hasattr(request, '_user'):
-        print('设置reqeust._user')
-        request.user = user
-
-    print('登录成功')
-    print(f'new_token = {new_token}')
-    return JsonResponse({'code': 200, 'token': new_token})
-
-
-@api_view(['POST'])
-def logout(request):
-    username = request.data.get('username', '')
-    token = request.data.get('token', '')
-    try:
-        user = User.objects.get(username=username)
-        try:
-            user_token = user.tokens.get(token=token)
-            user_token.delete()
-        except Exception as e:
-            print(e)
-            print('token无效')
-        return JsonResponse({'code': 200})
-    except Exception as e:
-        print(e)
-        return JsonResponse({'code': 303, 'error': str(e)}, status=303)
-
-
-@api_view(['POST'])
-def reset_password(request):
-    username = request.data.get('username', '')
-    password = request.data.get('password', '')
-    try:
-        user = User.objects.get(username=username)
-        token = request.data.get('token')
-        if token:
-            print(f'token={token}')
-            if user.check_token(token):
-                # 重置密码
-                print("验证码有效")
-                user.password = password
-                user.save()
-                return JsonResponse({'code': 200})
-            else:
-                print("验证码无效")
-                return JsonResponse({'code': 303, 'error': '验证码错误'}, status=303)
-        else:
-            # 发送验证码
-            token = user.make_token()
-            print(f'')
-            print(f'发送验证码 email = {user.email} token = {token}')
-            user.send_email('ST网盘重置密码验证码', token)
-            return JsonResponse({'code': 200})
-    except Exception as e:
-        print(e)
-        return JsonResponse({'code': 303, 'error': str(e)}, status=303)

+ 3 - 1
account/decorators.py

@@ -14,7 +14,7 @@ def user_passes_test(test_func, error):
         def _wrapped_view(request, *args, **kwargs):
             if test_func(request):
                 return view_func(request, *args, **kwargs)
-            return JsonResponse({'code': 401, 'error': error}, status=401)
+            return JsonResponse({'code': 401, 'error': error})
         return _wrapped_view
     return decorator
 
@@ -35,6 +35,8 @@ def login_required(function=None, error='error'):
             if user.check_token(token):
                 user.tokens.get(token=token)
                 print('已登录')
+                if hasattr(request, 'user'):
+                    request.user = user
                 return True
         except:
             print('未登录')

+ 3 - 0
account/models.py

@@ -81,6 +81,9 @@ class User(models.Model):
 
 
 def get_user(request):
+    if hasattr(request, 'user') and isinstance(request.user, User):
+        print(f'get user from request.user, username={request.user.username}')
+        return request.user
     username = request.POST.get('username', '')
     token = request.POST.get('token', '')
     try:

BIN
db.sqlite3


+ 19 - 7
file/views.py

@@ -36,16 +36,20 @@ def upload_file(request):
         file_name = file_obj.name
         folder_id = request.POST.get('folder_id')
         try:
-            folder = user.folders.get(folder_id=folder_id)
+            folder = Folder.objects.get(folder_id=folder_id)
         except:
             print('文件夹不存在')
             return JsonResponse({'code': 402, 'error': '文件夹不存在'})
+        if not folder.check_permission(user=user):
+            print('没有上传文件的权限')
+            return JsonResponse({'code': 404, 'error': '没有上传文件的权限'})
         file = File.objects.create(file_name=file_name,
                                    folder=folder,
                                    update_time=update_time,
                                    file_size=file_size,
                                    file_type=file_type,
-                                   owner=user)
+                                   owner=user,
+                                   group=folder.group)
         # TODO: 文件hash
         try:
             file_dir = BASE_DIR + '/' + str(file.file_id)
@@ -69,10 +73,13 @@ def download_file(request):
         user = get_user(request)
         file_id = request.POST.get('file_id')
         try:
-            file = user.files.get(file_id=file_id)
+            file = File.objects.get(file_id=file_id)
         except:
             print('文件不存在')
             return JsonResponse({'code': 401, 'error': '文件不存在'})
+        if not file.folder.check_permission(user=user):
+            print('没有下载文件的权限')
+            return JsonResponse({'code': 404, 'error': '没有下载文件的权限'})
         file_name = file.file_name
         file_dir = BASE_DIR + '/' + str(file.file_id)
         file = open(file_dir, 'rb')
@@ -93,15 +100,20 @@ def delete_file(request):
         user = get_user(request)
         file_id = data.get('file_id')
         try:
-            file = user.files.get(file_id=file_id)
+            file = File.objects.get(file_id=file_id)
         except:
             print('文件不存在')
             return JsonResponse({'code': 401, 'error': '文件不存在'})
-        file.delete()
+        if not file.folder.check_permission(user=user) or (
+                file.owner != user and (not file.group or file.group.creator != user)):
+            print('没有删除文件的权限')
+            return JsonResponse({'code': 404, 'error': '没有删除文件的权限'})
         try:
             os.remove(BASE_DIR + '/' + file_id)
-        except Exception as e:
-            print(e)
+        except:
+            print('文件删除失败')
+            return JsonResponse({'code': 500, 'error': '文件删除失败'})
+        file.delete()
         return JsonResponse({'code': 200})
     elif request.method == 'GET' and DEBUG:
         return render(request, 'delete_file.html')

+ 1 - 1
folder/admin.py

@@ -4,7 +4,7 @@ from .models import Folder
 
 # Register your models here.
 class FolderAdmin(admin.ModelAdmin):
-    list_display = ["folder_id", "folder_name", "father_folder", "owner"]
+    list_display = ["folder_id", "folder_name", "father_folder", "owner", "group"]
 
 
 admin.site.register(Folder, FolderAdmin)

+ 5 - 2
folder/models.py

@@ -16,15 +16,18 @@ class Folder(models.Model):
     # 文件夹名
     folder_name = models.CharField(max_length=50, blank=False, default='root')
     # 父节点
-    father_folder = models.ForeignKey('self', blank=True, on_delete=models.SET_NULL, null=True,
+    father_folder = models.ForeignKey('self', blank=True, on_delete=models.CASCADE, null=True,
                                       related_name='children_folders')
     # 所有者
-    owner = models.ForeignKey(User, on_delete=models.CASCADE, related_name='folders')
+    owner = models.ForeignKey(User, on_delete=models.DO_NOTHING, related_name='folders')
     group = models.ForeignKey(Group, on_delete=models.DO_NOTHING, related_name='folders', null=True)
 
     def to_json(self):
         return {'folder_id': self.folder_id, 'folder_name': self.folder_name, 'father_folder_id': self.father_folder_id}
 
+    def check_permission(self, user:User):
+        return self.owner == user or user.joined_groups.filter(group_id=self.group_id).count() > 0
+
     def __str__(self):
         return str(self.folder_id)
 

+ 2 - 2
folder/views.py

@@ -8,6 +8,7 @@ from .models import Folder
 from file.models import File
 import json
 from account.models import get_user
+from utils.decorators import debug_view
 
 # Create your views here.
 DEBUG = 1
@@ -17,7 +18,6 @@ DEBUG = 1
 def get_root_folder(request):
     # 获取根目录
     if request.method == 'POST':
-        # 获取所有根文件夹id
         user = get_user(request)
         return JsonResponse({'code': 200, 'root_folder_id': user.get_root_folder().folder_id})
     elif request.method == 'GET' and DEBUG:
@@ -31,7 +31,7 @@ def get_root_folder(request):
 def folder_list(request):
     if request.method == 'POST':
         data = request.POST
-        folder_id = data['folder_id']
+        folder_id = data.get('folder_id')
         user = get_user(request)
         try:
             folder = user.folders.get(folder_id=folder_id)

+ 18 - 0
group/migrations/0004_rename_member_group_members.py

@@ -0,0 +1,18 @@
+# Generated by Django 3.2.7 on 2021-09-10 04:21
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('group', '0003_auto_20210910_0010'),
+    ]
+
+    operations = [
+        migrations.RenameField(
+            model_name='group',
+            old_name='member',
+            new_name='members',
+        ),
+    ]

+ 21 - 1
group/models.py

@@ -2,6 +2,11 @@ from django.db import models
 
 from account.models import User
 
+# 引入内置信号
+from django.db.models.signals import post_save
+# 引入信号接收器的装饰器
+from django.dispatch import receiver
+
 
 class Group(models.Model):
     # 群id
@@ -11,5 +16,20 @@ class Group(models.Model):
     # 群管理员
     creator = models.ForeignKey(User, on_delete=models.CASCADE, related_name="my_groups")
     # 群成员(多对多)
-    member = models.ManyToManyField(User, related_name="joined_groups")
+    members = models.ManyToManyField(User, related_name="joined_groups")
+
+    def to_json(self):
+        return {'group_id': self.group_id, 'group_name': self.group_name}
+
+    def get_root_folder(self):
+        return self.folders.get(father_folder=None)
+
+    def __str__(self):
+        return str(self.group_id)
+
 
+# 信号接收函数,每当新建Group实例的时候自动调用
+@receiver(post_save, sender=Group)
+def add_creator_to_group_members(sender, instance, created, **kwargs):
+    if created and instance.creator not in instance.members:
+        instance.members.add(instance.creator)

+ 3 - 2
group/urls.py

@@ -3,8 +3,9 @@ from django.urls import path
 from . import views
 # Create your views here.
 urlpatterns = [
-    path('add_group/', views.add_group, name='add_group'),
-    path('leave_group/', views.leave_group, name='leave_group'),
+    path('get_group_root_folder/', views.get_group_root_folder, name='get_group_root_folder'),
+    path('join_group/', views.join_group, name='join_group'),
+    path('quit_group/', views.quit_group, name='quit_group'),
     path('group_list/', views.group_list, name='group_list'),
     path('create_group/', views.create_group, name='create_group')
 ]

+ 64 - 93
group/views.py

@@ -1,115 +1,86 @@
-from django.http import HttpResponse
-from django.shortcuts import render
 from account.decorators import login_required
 from .models import Group
-from account.models import User
-from folder.models import Folder
-import json
-import random
-import string
+from account.models import User, get_user
+from utils.decorators import debug_view
+from utils.http import make_json_response
+from django.views.decorators.http import require_GET, require_POST
 
 # Create your views here.
-DEBUG = 1
 
 
-# 首先是 加入/退出 群组,客户端发送申请,把该用户直接 加入/删除 该群组的对象
+@debug_view('get_group_root_folder.html')
+@require_POST
 @login_required
-def add_group(request):
-    if request.method == 'POST':
-        data = request.POST
-        # 查找是否有此人
-        user = User.objects.filter(username=data['username']).get()
-        if user:
-            # 查找此人是否已经在群组里
-            group = Group.objects.filter(group_id=data['group_id']).get()
-            if group.member.filter(username=user.username).count() > 0:
-                # 此人已在群组里
-                return HttpResponse(status=421)
-            else:
-                group.member.add(user)
-                group.save()
-                return HttpResponse(status=200)
-        else:
-            return HttpResponse(status=422)
-    elif request.method == 'GET':
-        if DEBUG:
-            return render(request, 'add_group.html')
-    else:
-        return HttpResponse(status=400)
+def get_group_root_folder(request):
+    user = get_user(request)
+    data = request.POST
+    group_id = data.get('group_id')
+    try:
+        group = user.joined_groups.get(group_id=group_id)
+    except:
+        return make_json_response(code=401, error='群不存在')
+    root_folder = group.get_root_folder()
+    return make_json_response(root_folder_id=root_folder.folder_id)
 
 
+@debug_view('join_group.html')
+@require_POST
 @login_required
-def create_group(request):
-    if request.method == 'POST':
-        data = request.POST
-        # 查找是否有此人
-        username = data.get('username', '')
-        user = User.objects.filter(username=username)
-        if user:
-            group_id_random = ''.join(random.sample(string.digits, 8))
-            folder_1 = Folder.objects.create(folder_id=group_id_random,
-                                             folder_name=data['group_name'],
-                                             father_folder=None)
-            folder_1.save()
-            group_1 = Group.objects.create(group_id=group_id_random,
-                                           group_name=data['group_name'],
-                                           creator=username,
-                                           folder=folder_1,)
-            group_1.member.set(user)
-            group_1.save()
-            return HttpResponse(status=200)
-        else:
-            return HttpResponse(status=422)
-    elif request.method == 'GET':
-        if DEBUG:
-            return render(request, 'create_group.html')
-    else:
-        return HttpResponse(status=400)
+def join_group(request):
+    user = get_user(request)
+    data = request.POST
+    group_id = data.get('group_id')
+    if user.joined_groups.filter(group_id=group_id).exists():
+        return make_json_response(code=402, error='已在群内')
+    try:
+        group = Group.objects.get(group_id=group_id)
+    except:
+        return make_json_response(code=401, error='群不存在')
+    group.members.add(user)
+    group.save()
+    return make_json_response()
+
 
+@debug_view('create_group.html')
+@require_POST
+@login_required
+def create_group(request):
+    user = get_user(request)
+    data = request.POST
+    group_name = data.get('group_name')
+    try:
+        Group.objects.create(group_name=group_name, creator=user)
+    except:
+        return make_json_response(code=500, error='新建群失败')
+    return make_json_response()
 
-# value = ''.join(random.sample(string.ascii_letters + string.digits, 8))
 
+@debug_view('quit_group.html')
+@require_POST
 @login_required
-def leave_group(request):
-    if request.method == 'POST':
-        data = request.POST
-        # 查找是否有此人
-        user = User.objects.filter(username=data['username']).get()
-        if user:
-            # 查找此人是否已经在群组里
-            group = Group.objects.filter(group_id=data['group_id']).get()
-            if group.member.filter(username=user.username).count() == 0:
-                # 此人不在群组里
-                return HttpResponse(status=421)
-            else:
-                group.member.remove(user)
-                group.save()
-                return HttpResponse(status=200)
-        else:
-            return HttpResponse(status=422)
-    elif request.method == 'GET':
-        if DEBUG:
-            return render(request, 'leave_group.html')
-    else:
-        return HttpResponse(status=400)
+def quit_group(request):
+    user = get_user(request)
+    data = request.POST
+    group_id = data.get('group_id')
+    try:
+        group = user.joined_groups.get(group_id=group_id)
+    except:
+        return make_json_response(code=401, error='群不存在')
+    if group.creator == user:
+        return make_json_response(code=402, error='群主不可退群')
+    group.members.remove(user)
+    group.save()
+    return make_json_response()
 
 
 # 获取你所在的所有群组
+@debug_view('group_list.html')
+@require_POST
 @login_required
 def group_list(request):
-    if request.method == 'POST':
-        data = request.POST
-        # 获取群组
-        groups = Group.objects.filter(member__username=data['username'])
-        response = []
-        for i in groups:
-            response.append(i.group_name)
-        return HttpResponse(json.dumps(response), status=200)
-    elif request.method == 'GET':
-        if DEBUG:
-            return render(request, 'group_list.html')
-    else:
-        return HttpResponse(status=400)
+    user = get_user(request)
+    _list = list(map(lambda g: g.to_json(), user.joined_groups.all()))
+    return make_json_response(group_list=_list)
 
 
 # 删除群组

+ 1 - 1
st_cloud/settings.py

@@ -23,7 +23,7 @@ BASE_DIR = Path(__file__).resolve().parent.parent
 SECRET_KEY = 'django-insecure-h1r^p(6-s&@7u!q(sv%_@97fxv(ikbi7d9p#i9+-o_3&pbpw(j'
 SALT = 'sa0v-038auwmd-r0awvy4-0y4vs9mdy9-aby09384vy-amr9tv8ybsva9v4y'
 
-# SECURITY WARNING: don't run with debug turned on in production!
+# SECURITY WARNING: don't run with utils turned on in production!
 DEBUG = True
 
 ALLOWED_HOSTS = []

+ 24 - 0
templates/get_group_root_folder.html

@@ -0,0 +1,24 @@
+<!DOCTYPE html>
+<html lang="zh-cn">
+    <div>
+        <form method="post" action=".">
+            {% csrf_token %}
+            <!-- 用户名 -->
+            <div>
+                <label for="username">用户名</label>
+                <input type="text" id="username" name="username">
+            </div>
+            <div>
+                <label for="token">token</label>
+                <input type="text" id="token" name="token">
+            </div>
+            <!-- 群号 -->
+            <div>
+                <label for="group_id">群号</label>
+                <input type="text" id="group_id" name="group_id">
+            </div>
+            <!-- 提交按钮 -->
+            <button type="submit">提交</button>
+        </form>
+    </div>
+</html>

+ 0 - 0
templates/add_group.html → templates/join_group.html


+ 0 - 0
templates/leave_group.html → templates/quit_group.html


+ 0 - 0
utils/__init__.py


+ 18 - 0
utils/decorators.py

@@ -0,0 +1,18 @@
+from functools import wraps
+from urllib.parse import urlparse
+
+from django.conf import settings
+from django.shortcuts import render
+from django.http import JsonResponse
+
+DEBUG = settings.DEBUG
+
+
+def debug_view(template_name):
+    def decorator(view_func):
+        @wraps(view_func)
+        def _wrapped_view(request, *args, **kwargs):
+            return render(request, template_name) if DEBUG and request.method == 'GET' \
+                else view_func(request, *args, **kwargs)
+        return _wrapped_view
+    return decorator

+ 10 - 0
utils/http.py

@@ -0,0 +1,10 @@
+from django.http import HttpResponse, JsonResponse
+
+
+def make_json_response(**kwargs):
+    if not kwargs.get('code'):
+        kwargs['code'] = 200
+    error = kwargs.get('error')
+    if error:
+        print(error)
+    return JsonResponse(kwargs, json_dumps_params={"ensure_ascii": False})